diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py
new file mode 100644
index 000000000..659d51bb9
--- /dev/null
+++ b/megatron/data/distdata.py
@@ -0,0 +1,220 @@
+import os
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+
+class DistDataError(Exception):
+    """Defines an empty exception to throw when some other rank hits a real exception."""
+    pass
+
+class DistData(object):
+    def __init__(self, backend='gloo'):
+        assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'"
+
+        dist.init_process_group(backend, init_method="env://")
+
+        # lookup our process rank and the group size
+        self.rank = dist.get_rank()
+        self.numranks = dist.get_world_size()
+
+    def allassert(self, cond, msg):
+        """Check that cond is True on all ranks, assert with msg everywhere if not.
+
+        To prevent deadlocks in cases where an assertion might only fail on one rank,
+        this executes an allreduce to ensure that if any rank finds that an assertion
+        has been violated, all ranks fail an assertion check.
+        The condition must be true on all ranks for this not to assert.
+        """
+        alltrue = self.alltrue(cond)
+        assert alltrue, msg
+
+    def allraise_if(self, err):
+        """Raise an exception if err is not None on any rank.
+
+        Similar to allassert, this raises an exception on all ranks if err
+        is set to an exception on any rank. Rank(s) where err is not None
+        re-raise err as the exception, and ranks where err is None raise DistDataError.
+        Thus all ranks raise an exception if any rank has an active exception,
+        which helps avoid deadlocks in cases where an exception may be raised
+        on a subset of ranks.
+        """
+        alltrue = self.alltrue(err is None)
+        if not alltrue:
+            # At least one rank raised an exception.
+            # Re-raise the actual exception if this rank threw one.
+            if err is not None:
+                raise err
+
+            # TODO: is there a better exception to use here?
+            # On other ranks, raise an "empty" exception to indicate
+            # that we're only failing because someone else did.
+            raise DistDataError
+
+    def barrier(self):
+        """Globally synchronize all processes"""
+        dist.barrier()
+
+    def bcast(self, val, root):
+        """Broadcast a scalar value from root to all ranks"""
+        vals = [val]
+        dist.broadcast_object_list(vals, src=root)
+        return vals[0]
+
+    def scatterv_(self, invals: np.array, counts: list, root: int = 0):
+        """Scatter int64 values from invals according to counts array, return received portion in a new tensor"""
+
+        self.allassert(len(counts) == self.numranks,
+            f"Length of counts list {len(counts)} does not match number of ranks {self.numranks}")
+
+        # Define list of tensors to scatter on the root.
+        # torch.distributed.scatter requires each tensor to be the same shape,
+        # so find the max size across all count values and pad.
+        max_size = max(counts)
+        scatterlist = None
+        if self.rank == root:
+            slices = list(torch.split(torch.from_numpy(invals), counts))
+            scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices]
+
+        # Receive a tensor of the max count size from the root,
+        # then copy values into output numpy array, which may be smaller.
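+        # For example, with three ranks and counts = [4, 2, 3], max_size is 4:
+        # rank 1's slice is padded from 2 to 4 values before the scatter, and the
+        # padding is trimmed off below by keeping only the first counts[rank] values.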
+        recvtensor = torch.zeros(max_size, dtype=torch.int64)
+        dist.scatter(recvtensor, scatterlist, src=root)
+        return recvtensor[:counts[self.rank]]
+
+    def alltrue(self, val):
+        """Returns True if all procs input True, False otherwise"""
+        # torch.dist does not support reductions with bool types
+        # so we cast to int and cast the result back to bool
+        tensor = torch.tensor([int(val)], dtype=torch.int32)
+        dist.all_reduce(tensor, op=dist.ReduceOp.BAND)
+        return bool(tensor[0])
+
+    def sum(self, val):
+        """Compute sum of a scalar val, and return total on all ranks."""
+        tensor = torch.tensor([val])
+        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+        return tensor[0]
+
+    def exscan(self, val: int):
+        """Compute prefix sum (exclusive scan) of int64 val, and return offset of each rank."""
+        # torch.distributed doesn't have a scan, so fallback to allreduce
+        tensor = torch.zeros(self.numranks, dtype=torch.int64)
+        tensor[self.rank:] = val
+        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+        return int(tensor[self.rank]) - val
+
+    def min(self, val):
+        """Return minimum of scalar val to all ranks."""
+        tensor = torch.tensor([val])
+        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+        return tensor[0]
+
+    def minrank(self, cond):
+        """Find first rank whose condition is True, return that rank if any, None otherwise."""
+        minrank = self.numranks
+        if cond:
+            minrank = self.rank
+        minrank = self.min(minrank)
+
+        if minrank < self.numranks:
+            return minrank
+        return None
+
+    def bcast_first(self, val):
+        """Broadcast val from first rank where it is not None, return val if any, None otherwise"""
+        # Find the first rank with a valid value.
+        minrank = self.minrank(val is not None)
+
+        # If there is no rank with a valid value, return None
+        if minrank is None:
+            return None
+
+        # Otherwise broadcast the value from the first valid rank.
+        val = self.bcast(val, root=minrank)
+        return val
+
+    def all_sum_(self, vals: np.array):
+        """Sums values in numpy array vals element-wise and update vals in place with final result on all ranks"""
+        # Builds torch.tensor with from_numpy to use same underlying memory as numpy array.
+        tensor = torch.from_numpy(vals)
+        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+
+    def open(self, filename, truncate=None):
+        """Create, truncate, and open a file shared by all ranks."""
+
+        # Don't truncate existing file until all ranks reach this point
+        self.barrier()
+
+        # We'll capture any exception in this variable
+        err = None
+
+        # Rank 0 creates and truncates file.
+        if self.rank == 0:
+            try:
+                f = open(filename, 'wb')
+
+                # Some file systems like GPFS deliver faster write speed
+                # if the file size is known before data is written to the file.
+                if truncate is not None:
+                    f.truncate(truncate)
+
+            except Exception as e:
+                err = e
+
+        # Verify that rank 0 created the file
+        self.allraise_if(err)
+
+        # Wait for rank 0 to open (and truncate) file,
+        # then have all ranks open file for writing.
+        if self.rank != 0:
+            try:
+                f = open(filename, 'r+b')
+            except Exception as e:
+                err = e
+
+        # Verify that all ranks successfully opened the file
+        self.allraise_if(err)
+
+        return f
+
+    def remove(self, filename):
+        """Remove a shared file."""
+
+        # Don't remove the file until all are ready
+        self.barrier()
+
+        # We'll capture any exception in this variable
+        err = None
+
+        # Rank 0 removes the file if it exists.
+        if self.rank == 0:
+            try:
+                if os.path.exists(filename):
+                    os.remove(filename)
+            except Exception as e:
+                err = e
+
+        # Verify that rank 0 successfully removed the file.
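+        # As in open(), allraise_if re-raises the error on rank 0 and raises
+        # DistDataError on all other ranks, so every rank fails together
+        # instead of hanging in a later barrier.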
+        self.allraise_if(err)
+
+    def rename(self, srcfile, destfile):
+        """Rename a shared file."""
+
+        # Don't rename until all are ready
+        self.barrier()
+
+        # We'll capture any exception in this variable
+        err = None
+
+        # Rank 0 renames the file.
+        if self.rank == 0:
+            try:
+                if os.path.exists(srcfile):
+                    os.rename(srcfile, destfile)
+            except Exception as e:
+                err = e
+
+        # Verify that the rename succeeded
+        self.allraise_if(err)
diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py
index acdf36246..025e9e333 100644
--- a/megatron/data/indexed_dataset.py
+++ b/megatron/data/indexed_dataset.py
@@ -12,6 +12,7 @@ from functools import lru_cache
 import os
+import stat
 import shutil
 import struct
 from itertools import accumulate
@@ -268,6 +269,7 @@ class IndexedDatasetBuilder(object):
     element_sizes = {
         np.uint8: 1,
         np.int8: 1,
+        np.uint16: 2,
         np.int16: 2,
         np.int32: 4,
         np.int64: 8,
@@ -275,6 +277,22 @@
         np.double: 8
     }
 
+    @staticmethod
+    def write_header(fout, dtype, numdata, numsize, numdoc):
+        """Writes header for cached indexed dataset to given file handle, return number of bytes written."""
+        startpos = fout.tell()
+
+        fout.write(IndexedDataset._HDR_MAGIC)
+        fout.write(struct.pack('
+
+def exscan_from_cumsum_(arr):
+    # --> [0, 10, 30, 35]
+    if arr.size > 1:
+        arr[1:] = arr[:-1]
+    if arr.size > 0:
+        arr[0] = 0
+
+
+def get_pointers_with_total(sizes, elemsize, dtype):
+    """Return a numpy array of type np.dtype giving the byte offsets.
+
+    Multiplies values in the sizes array by elemsize (bytes),
+    and then computes an exclusive scan to get byte offsets.
+    Returns the total number of bytes as second item in a tuple.
+    """
+
+    # scale values in sizes array by elemsize to get sizes in bytes
+    pointers = np.array(sizes, dtype=dtype)
+    pointers *= elemsize
+    np.cumsum(pointers, axis=0, out=pointers)
+
+    # get total number of bytes from all sizes (last element)
+    bytes_last = pointers[-1] if len(sizes) > 0 else 0
+
+    # convert to byte offsets
+    exscan_from_cumsum_(pointers)
+
+    return pointers, bytes_last
+
+
 class MMapIndexedDataset(torch.utils.data.Dataset):
     class Index(object):
         _HDR_MAGIC = b'MMIDIDX\x00\x00'
 
+        @staticmethod
+        def write_header(fout, dtype, numsizes, numdocs):
+            """Writes header for mmap indexed dataset to given file handle, return number of bytes written."""
+            startpos = fout.tell()
+
+            fout.write(MMapIndexedDataset.Index._HDR_MAGIC)
+            fout.write(struct.pack('
+
+def gather_files_dist(filemain, filelist, distctx):
+    assert distctx.sum(len(filelist)) > 0, "All ranks have no input files to merge"
+
+    # Check that files are all of the same index type
+    indexstr = gather_files_dist_check_impltype(filelist, distctx)
+
+    # Concatenate the data files
+    gather_files_dist_bin(filemain, filelist, distctx)
+
+    # Combine index files into a single index file
+    if indexstr == "cached":
+        gather_files_dist_idx_cached(filemain, filelist, distctx)
+    elif indexstr == "mmap":
+        gather_files_dist_idx_mmap(filemain, filelist, distctx)
diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py
index 4d27d94b3..3723f97bc 100644
--- a/tests/test_preprocessing.py
+++ b/tests/test_preprocessing.py
@@ -78,6 +78,13 @@ def test_preprocess_data(self):
             tgt_path = f"{output_prefix}_text_document.{ext}"
             self.assertTrue(Path(tgt_path).exists(), )
 
+    def compare_meg_data_files(self, tgt, ref):
+        for ext in ["bin", "idx"]:
+            tgt_path = f"{tgt}.{ext}"
+            ref_path = f"{ref}.{ext}"
+            self.assertTrue(Path(tgt_path).exists(), )
+            self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False))
+
     def test_process_data_microsoft(self):
         """We want to be stable to Microsoft version."""
         src_dir = self.src_dir
@@ -104,9 +111,58 @@ def test_process_data_microsoft(self):
         # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
         execute_subprocess_async(cmd, env=self.get_env())
 
-        for ext in ["bin", "idx"]:
-            tgt_path = f"{output_prefix}_text_document.{ext}"
-            ref_path = f"{data_dir}/meg-gpt2-openwebtext_text_document.{ext}"
-            self.assertTrue(Path(tgt_path).exists(), )
-            self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False))
+        self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document")
+
+    def test_process_data_dist_microsoft(self):
+        """We want to be stable to Microsoft version."""
+        src_dir = self.src_dir
+        data_dir = f"{self.data_dir}/gpt2"
+        output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False)
+
+        output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext_1k"
+
+        cmd = f"""
+            python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py
+                --input openwebtext-10k
+                --count 1000
+                --output-prefix {output_prefix}
+                --dataset-impl mmap
+                --tokenizer-type GPT2BPETokenizer
+                --merge-file {data_dir}/gpt2-tiny-merges.txt
+                --vocab {data_dir}/gpt2-tiny-vocab.json
+                --append-eod
+        """.split()
+
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+        execute_subprocess_async(cmd, env=self.get_env())
+
+        self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document")
+
+    def test_process_data_dist_serial_microsoft(self):
+        """We want to be stable to Microsoft version."""
+        src_dir = self.src_dir
+        data_dir = f"{self.data_dir}/gpt2"
+        output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False)
+
+        output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext_1k"
+
+        cmd = f"""
+            python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py
+                --input openwebtext-10k
+                --count 1000
+                --merge serial
+                --output-prefix {output_prefix}
+                --dataset-impl mmap
+                --tokenizer-type GPT2BPETokenizer
+                --merge-file {data_dir}/gpt2-tiny-merges.txt
+                --vocab {data_dir}/gpt2-tiny-vocab.json
+                --append-eod
+        """.split()
+
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+        execute_subprocess_async(cmd, env=self.get_env())
+
+        self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document")
diff --git a/tools/preprocess_dataset_mpi.py b/tools/preprocess_data_dist.py
similarity index 61%
rename from tools/preprocess_dataset_mpi.py
rename to tools/preprocess_data_dist.py
index 3cd366160..a2c9c0970 100644
--- a/tools/preprocess_dataset_mpi.py
+++ b/tools/preprocess_data_dist.py
@@ -20,26 +20,28 @@
     from datasets import load_dataset
     dset = load_dataset('openwebtext', split='train')
 
-The implementation can use `mpi4py` or `torch.distributed` for node communication, and it assumes that
-files are written to a global file system, such that one process
+The implementation uses `torch.distributed` for inter-process communication,
+and it assumes that files are written to a global file system, such that one process
 can read a file written by another process.
 
 A list of sample index values from the source dataset are selected
-by rank 0 and broadcast to all ranks.
+by rank 0 and scattered to all ranks.
 Each process tokenizes a subset of samples and writes its output to a part file.
-After all ranks have finished, rank 0 merges and deletes the part files.
+After all ranks have finished, the part files are merged into a final output file.
+
+One may optionally use storage local to each process to store the part file.
+For example, on a Linux cluster, one might write the part file to /dev/shm.
 
 To run:
 
-mpiexec -np 320 python preprocess_dataset_mpi.py \
-       --input openwebtext \
-       --shuffle \
-       --seed 100 \
-       --output-prefix openwebtext-bert \
-       --vocab bert-large-uncased-vocab.txt \
-       --dataset-impl mmap \
-       --tokenizer-type BertWordPieceLowerCase \
-       --split-sentences
+python -m torch.distributed.launch --nproc_per_node 40 --nnodes 8 \
+       preprocess_data_dist.py \
+       --input openwebtext \
+       --output-prefix openwebtext-bert \
+       --vocab bert-large-uncased-vocab.txt \
+       --dataset-impl mmap \
+       --tokenizer-type BertWordPieceLowerCase \
+       --split-sentences
 """
 
 import argparse
@@ -54,7 +56,6 @@
 import random
 
 import torch
-import torch.distributed as dist
 try:
     import nltk
     nltk_available = True
@@ -65,7 +66,15 @@
 from datasets.utils.file_utils import OfflineModeIsEnabled
 
 from megatron.tokenizer import build_tokenizer
-from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype
+from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, gather_files_dist
+from megatron.data.distdata import DistData
+
+def msg(msg, flush=False):
+    timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
+    print(f"{timestamp}: {msg}", flush=flush)
+
+def msgerr(msg, flush=False):
+    print(f"ERROR: {msg}", flush=flush)
 
 # https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
 class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
@@ -92,7 +101,7 @@ def __init__(self, args):
         if self.args.split_sentences:
             if not nltk_available:
-                print("NLTK is not available to split sentences.")
+                msgerr("NLTK is not available to split sentences.")
                 exit()
             splitter = nltk.load("tokenizers/punkt/english.pickle")
             if self.args.keep_newlines:
@@ -160,41 +169,52 @@ def get_args():
                        choices=['lazy', 'cached', 'mmap'])
 
     group = parser.add_argument_group(title='runtime')
-    group.add_argument('--mpi4py', action='store_true',
-                       help='Assume script has been launched as an MPI job, and use MPI for communication.')
-    group.add_argument('--torch-backend', type=str, default='gloo', choices = ['gloo', 'mpi'],
+    group.add_argument('--torch-backend', type=str, default='gloo', choices=['gloo', 'mpi'],
                        help='Select torch.distributed backend.')
     group.add_argument('--local_rank', type=int, default=None,
                        help='Local rank of calling process on its node (from torch.distributed.launch).')
+    group.add_argument('--merge', type=str, default='parallel', choices=['parallel', 'serial', 'both'],
+                       help=('Method to merge intermediate per-rank files into the final data files. '
+                             'With "parallel", each rank writes directly to the final files, '
+                             'while rank 0 copies data from all per-rank files with "serial". '
+                             'A parallel merge can be faster, but for correctness, it requires the underlying file system '
+                             'to support parallel write operations to a file that is shared among multiple processes. '
+                             'One can choose "both" for testing purposes, in which case the final files written '
+                             'by the parallel method are given an additional ".par" extension.'))
+    group.add_argument('--scratch', type=str, default=None,
+                       help=('Path to local storage on compute nodes to write per-rank files before merging, like /dev/shm. '
+                             'One can only use this option with a parallel merge.'))
     group.add_argument('--log-interval', type=int, default=30,
                        help='Seconds between progress updates (0 to disable)')
 
     args = parser.parse_args()
     args.keep_empty = False
 
-    if args.tokenizer_type.lower().startswith('bert'):
-        if not args.split_sentences:
-            print("Bert tokenizer detected, are you sure you don't want to split sentences?")
+    # initialize our distributed environment
+    args.distctx = DistData(backend=args.torch_backend)
 
-    args.level = "document"
-    if args.split_sentences:
-        args.level = "sentence"
+    # some functions like build_tokenizer use args.rank to filter stdout messages
+    args.rank = args.distctx.rank
+    args.numranks = args.distctx.numranks
 
     # some default/dummy values for the tokenizer
-    args.rank = 0
     args.make_vocab_size_divisible_by = 128
     args.tensor_model_parallel_size = 1
     args.vocab_extra_ids = 0
 
-    # use mpi4py instead of torch.distributed if requested
-    args.use_mpi = False
-    if args.mpi4py:
-        try:
-            from mpi4py import MPI
-            args.MPI = MPI
-            args.use_mpi = True
-        except:
-            print(f"ERROR: mpi4py requested, but failed to import, falling back to torch.distributed.", flush=True)
+    if args.tokenizer_type.lower().startswith('bert'):
+        if not args.split_sentences:
+            if args.rank == 0:
+                msg("Bert tokenizer detected, are you sure you don't want to split sentences?")
+
+    args.level = "document"
+    if args.split_sentences:
+        args.level = "sentence"
+
+    # TODO: perhaps more user friendly to disable scratch and print a warning?
+    # check that serial merge is not attempted with scratch
+    if args.scratch is not None and args.merge != 'parallel':
+        raise ValueError("The --scratch option is only valid with --merge=parallel")
 
     return args
 
@@ -202,65 +222,6 @@ def format_byterate(byterate):
     mbps = byterate / (1024.0 * 1024.0)
     return f"{mbps:0.3f} MB/s"
 
-def init_distributed(args):
-    """Determine which distributed runtime to use and connect up processes"""
-    # select our distributed runtime (MPI or torch.distributed)
-    # lookup our process rank and the group size
-    # some functions like build_tokenizer use args.rank to filter stdout messages
-    if args.use_mpi:
-        args.mpi_comm = args.MPI.COMM_WORLD
-        args.rank = args.mpi_comm.Get_rank()
-        args.numranks = args.mpi_comm.Get_size()
-    else:
-        dist.init_process_group(args.torch_backend, init_method="env://")
-        args.rank = dist.get_rank()
-        args.numranks = dist.get_world_size()
-
-def barrier(args):
-    """Globally synchronize all processes."""
-    if args.use_mpi:
-        args.mpi_comm.barrier()
-    else:
-        dist.barrier()
-
-def scatterv_(args, invals, counts, outval, root=0):
-    """Scatter int64 values from invals according to counts array, receive values in outval"""
-    assert len(counts) == args.numranks, f"Length of counts list {len(counts)} does not match number of ranks {args.numranks}"
-    assert outval.shape == (counts[args.rank],), f"Rank {args.rank}: output buffer is of shape {outval.shape}, expected {(counts[args.rank],)}"
-
-    if args.use_mpi:
-        counts = np.array(counts)
-        displs = np.cumsum(counts) - counts
-        args.mpi_comm.Scatterv([invals, counts, displs, args.MPI.INT64_T], outval, root=root)
-    else:
-        scatterlist = None
-        if args.rank == root:
-            scatterlist = list(torch.split(torch.from_numpy(invals), counts))
-        outtensor = torch.from_numpy(outval)
-        dist.scatter(outtensor, scatterlist, src=root)
-
-def all_sum_(args, vals):
-    """Sums values in vals element-wise and updates vals with final result on all ranks"""
-    if args.use_mpi:
-        outval = np.zeros_like(vals)
-        args.mpi_comm.Allreduce(vals, outval, op=args.MPI.SUM)
-        vals[:] = outval
-    else:
-        tensor = torch.from_numpy(vals)
-        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
-
-def all_true(args, val):
-    """Returns True if all procs input True, False otherwise"""
-    if args.use_mpi:
-        inval = np.array([val], dtype=np.bool_)
-        outval = np.zeros_like(inval)
-        args.mpi_comm.Allreduce(inval, outval, op=args.MPI.LAND)
-        return bool(outval[0])
-    else:
-        tensor = torch.tensor([int(val)], dtype=torch.int32)
-        dist.all_reduce(tensor, op=dist.ReduceOp.BAND)
-        return bool(tensor[0])
-
 def load_dset(args):
     # Avoid downloading datasets unless explicitly requested.
     # We allow the user to override this behavior if they set $HF_DATASETS_OFFLINE.
@@ -280,32 +241,29 @@
     if args.rank != 0:
         logging.set_verbosity(logging.ERROR)
 
+    time_start = time.time()
+
     # Load the specified HuggingFace dataset.
     # Give rank 0 a head start in case the dataset is not already cached.
-    success = True
     err = None
     dsetname = args.input
     if args.rank == 0:
-        print(f"Opening dataset {dsetname}")
+        msg(f"Opening dataset {dsetname}")
         try:
            dset = load_dataset(dsetname, split=args.split, keep_in_memory=None)
        except OfflineModeIsEnabled as e:
-            print(f"ERROR: Cannot download '{dsetname}' since running in offline mode.")
-            print(f"ERROR: If the dataset is large, it may be more efficient to download with a single process:")
-            print(f"ERROR:     from datasets import load_dataset")
-            print(f"ERROR:     dset = load_dataset('{dsetname}')")
-            print(f"ERROR: Alternatively, one can force this script to download by setting $HF_DATASETS_OFFLINE=0", flush=True)
-            success = False
+            msgerr(f"Cannot download '{dsetname}' since running in offline mode.")
+            msgerr(f"If the dataset is large, it may be more efficient to download with a single process:")
+            msgerr(f"    from datasets import load_dataset")
+            msgerr(f"    dset = load_dataset('{dsetname}')")
+            msgerr(f"Alternatively, one can force this script to download by setting $HF_DATASETS_OFFLINE=0", flush=True)
             err = e
         except Exception as e:
-            print("ERROR: Unexpected error:", sys.exc_info()[0], flush=True)
-            success = False
+            msgerr(f"Unexpected error: {sys.exc_info()[0]}", flush=True)
             err = e
 
     # determine whether rank 0 succeeded in loading the dataset
-    success = all_true(args, success)
-    if not success:
-        return None, err
+    args.distctx.allraise_if(err)
 
     # Rank 0 succeeded, attempt to load dataset on all other ranks.
     # This should load from cache now.
@@ -314,18 +272,17 @@
     if args.rank != 0:
         try:
             dset = load_dataset(dsetname, split=args.split, keep_in_memory=None)
         except Exception as e:
             # this print might be noisy, but better than nothing
-            print("ERROR: Unexpected error:", sys.exc_info()[0], flush=True)
-            success = False
+            msgerr(f"Unexpected error: {sys.exc_info()[0]}", flush=True)
             err = e
 
     # verify that all ranks loaded the dataset
-    success = all_true(args, success)
-    if not success:
-        if args.rank == 0:
-            print(f"ERROR: At least one process failed to load {dsetname}", flush=True)
-        return None, err
+    args.distctx.allraise_if(err)
+
+    time_end = time.time()
+    if args.rank == 0:
+        msg(f"Seconds to load dataset: {time_end - time_start}", flush=True)
 
-    return dset, err
+    return dset
 
 def get_num_samples(args, dset_size):
     """Given a dataset size and optional count argument, return number of samples to process."""
@@ -342,6 +299,7 @@
 def select_sample_list(args, dset_size):
     # create sample index list on rank 0,
     # optionally shuffle the list,
     # and optionally limit the sample count
+    time_select = time.time()
     idxlist = None
     if args.rank == 0:
         # generate a list of all index values
@@ -360,12 +318,19 @@
     # get a list of the number of elements each rank will hold
     counts = get_proc_counts(num_samples, args.numranks)
 
-    # allocate space to hold its portion of the list
-    idx = np.zeros(counts[args.rank], np.int64)
-
     # scatter sample index values from rank 0 to all procs
     # based on distribution defined in counts list
-    scatterv_(args, idxlist, counts, idx, root=0)
+    time_bcast = time.time()
+    idx = args.distctx.scatterv_(idxlist, counts, root=0)
+
+    args.distctx.barrier()
+    time_end = time.time()
+    if args.rank == 0:
+        msg(f"Select index stats:")
+        msg(f"    Shuffle: {args.shuffle}")
+        msg(f"    Seconds to select: {time_bcast - time_select}")
+        msg(f"    Seconds to broadcast: {time_end - time_bcast}")
+        msg(f"    Seconds total: {time_end - time_select}", flush=True)
 
     return idx
 
@@ -376,6 +341,11 @@ def get_proc_counts(num, num_ranks):
 
 def get_filename(args, key, rank=None):
     pathname = args.output_prefix
 
+    # redirect per-rank file to scratch dir if defined
+    if args.scratch is not None and rank is not None:
+        basename = os.path.basename(pathname)
+        pathname = os.path.join(args.scratch, basename)
+
     if rank is not None:
         filename = f"{pathname}_{key}_{args.level}_{rank}"
     else:
@@ -384,7 +354,7 @@
     return filename
 
 def rank_files_write(args, dset, idx, encoder):
-    tokenize_start = time.time()
+    time_start = time.time()
 
     # compute total number of samples we'e processing
     num_samples = get_num_samples(args, len(dset))
@@ -394,13 +364,13 @@
     dset_stats = np.zeros(3, dtype=np.int64)  # docs, sentences, bytes
 
     # we'll set this to false on any problem
-    success = True
     err = None
+    times = np.zeros(3, dtype=np.float32)  # read, tokenize, write
 
     try:
         # create data file for each rank
         if args.rank == 0:
-            print(f"Vocab size: {args.vocab_size}")
-            print(f"Output prefix: {args.output_prefix}")
+            msg(f"Vocab size: {args.vocab_size}")
+            msg(f"Output prefix: {args.output_prefix}")
         output_bin_files = {}
         output_idx_files = {}
         builders = {}
         for key in args.columns:
             filebase = get_filename(args, key, args.rank)
             output_bin_files[key] = data_file_path(filebase)
             output_idx_files[key] = index_file_path(filebase)
+            best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None
             builders[key] = make_builder(output_bin_files[key], impl=args.dataset_impl,
-                                         dtype=best_fitting_dtype(args.vocab_size))
+                                         dtype=best_dtype)
 
         # each rank tokenizes its samples and writes its own file
         progress_next = time.time() + float(args.log_interval)
@@ -418,10 +389,13 @@
             sample_id = int(i)
             for key in args.columns:
                 # tokenize text for the given sample index
+                start_read = time.time()
                 text = dset[sample_id][key]
+                start_encode = time.time()
                 doc, bytes_processed = encoder.encode_text(text)
 
                 # add tokenized sequence to our data file
+                start_write = time.time()
                 for key, sentences in doc.items():
                     for sentence in sentences:
                         builders[key].add_item(torch.IntTensor(sentence))
@@ -429,22 +403,26 @@
                 dset_stats[0] += 1
                 dset_stats[1] += len(sentences)
                 dset_stats[2] += bytes_processed
+                end_write = time.time()
+
+                times[0] += start_encode - start_read
+                times[1] += start_write - start_encode
+                times[2] += end_write - start_write
 
             if args.rank == 0 and args.log_interval > 0 and time.time() > progress_next:
                 current = time.time()
                 progress_next = current + float(args.log_interval)
 
-                elapsed = current - tokenize_start
-                timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
+                elapsed = current - time_start
                 docs = dset_stats[0] * args.numranks
                 percent = docs / num_samples * 100.0
                 docrate = docs / elapsed if elapsed > 0.0 else 0.0
                 mbs = dset_stats[2] * args.numranks / elapsed / 1024 / 1024 if elapsed > 0.0 else 0.0
                 secs_left = int((num_samples - docs) / docrate if docrate > 0.0 else 0.0)
-                print(f"{timestamp}: Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%),",
-                    f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s,",
-                    f"{secs_left} secs left ...",
-                    flush=True)
+                msg(f"Processed (estimated) {docs} of {num_samples} docs ({percent:0.2f}%) in {int(elapsed)} secs, "
+                    f"{docrate:0.3f} docs/s, {mbs:0.3f} MB/s, "
+                    f"{secs_left} secs left ...",
+                    flush=True)
 
         # finalize file of each rank
         for key in args.columns:
@@ -452,40 +430,80 @@
             del builders[key]  # file closed in __del__
     except Exception as e:
         # caught an exception, assume our file is invalid
-        success = False
         err = e
 
     # In case rank 0 finishes early and stops printing progress messages,
     # inform user that it's waiting for other ranks to finish.
     if args.rank == 0 and args.log_interval > 0:
-        timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
-        print(f"{timestamp}: Waiting for ranks to finalize files ...", flush=True)
+        msg(f"Waiting for ranks to finalize files ...", flush=True)
 
     # wait for all ranks to finish their files
-    barrier(args)
-    tokenize_end = time.time()
+    args.distctx.barrier()
+    time_end = time.time()
 
     # compute total stats across all processes
-    all_sum_(args, dset_stats)
+    args.distctx.all_sum_(times)
+    args.distctx.all_sum_(dset_stats)
     if args.rank == 0:
-        secs = tokenize_end - tokenize_start
+        secs = time_end - time_start
         docrate = dset_stats[0] / secs if secs > 0.0 else 0.0
         sentrate = dset_stats[1] / secs if secs > 0.0 else 0.0
         byterate = dset_stats[2] / secs if secs > 0.0 else 0.0
-        print("Tokenize stats:", secs)
-        print(f"    Seconds to tokenize: {secs}")
-        print(f"    {dset_stats[0]} docs {docrate} docs/sec")
-        print(f"    {dset_stats[1]} sents {sentrate} sents/sec")
-        print(f"    {dset_stats[2]} bytes {format_byterate(byterate)}")
-
-    # allreduce to check whether all ranks wrote their part successfully
-    success = all_true(args, success)
-    return success, err
-
-def rank_files_merge(args):
-    # rank 0 merges all per-rank files
+        secs_read_per_sample = times[0] / dset_stats[0] if dset_stats[0] > 0 else 0.0
+        secs_encode_per_sample = times[1] / dset_stats[0] if dset_stats[0] > 0 else 0.0
+        secs_write_per_sample = times[2] / dset_stats[0] if dset_stats[0] > 0 else 0.0
+        msg("Process stats:")
+        msg(f"    Seconds to process: {secs}")
+        msg(f"    {dset_stats[0]} docs {docrate} docs/sec")
+        msg(f"    {dset_stats[1]} sents {sentrate} sents/sec")
+        msg(f"    {dset_stats[2]} bytes {format_byterate(byterate)}")
+        msg(f"    Total read seconds {times[0]}, {secs_read_per_sample} sec/sample")
+        msg(f"    Total encode seconds {times[1]}, {secs_encode_per_sample} sec/sample")
+        msg(f"    Total write seconds {times[2]}, {secs_write_per_sample} sec/sample")
+
+    # check whether all ranks wrote their part successfully
+    args.distctx.allraise_if(err)
+
+def rank_files_merge_parallel(args):
+    """Each process directly writes its portion of the data from its per-rank file into the final file."""
+    merge_start = time.time()
+    numbytes = np.zeros(1, dtype=np.int64)
+    for key in args.columns:
+        # merge the per-rank file from each process into a single shared file
+        filemain = get_filename(args, key)
+        filerank = get_filename(args, key, args.rank)
+        gather_files_dist(filemain, [filerank], args.distctx)
+
+        # total up bytes read during the merge
+        binfilerank = data_file_path(filerank)
+        idxfilerank = index_file_path(filerank)
+        numbytes[0] += os.stat(binfilerank)[stat.ST_SIZE]
+        numbytes[0] += os.stat(idxfilerank)[stat.ST_SIZE]
+
+        # If the user wants to use both a parallel and serial merge (for testing),
+        # rename the parallel output files so that the serial merge does not clobber them.
+        if args.merge == 'both' and args.rank == 0:
+            binfilemain = data_file_path(filemain)
+            idxfilemain = index_file_path(filemain)
+            os.rename(binfilemain, binfilemain + ".par")
+            os.rename(idxfilemain, idxfilemain + ".par")
+
+    # Total up number of bytes read across all ranks,
+    # and wait on all ranks before stopping the timer.
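+    # Note: all_sum_ reduces the numpy array in place, so after this call
+    # numbytes[0] holds the global total on every rank.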
+    args.distctx.all_sum_(numbytes)
+    merge_end = time.time()
+    if args.rank == 0:
+        secs = merge_end - merge_start
+        byterate = numbytes[0] / secs if secs > 0.0 else 0.0
+        msg("Parallel merge stats:")
+        msg(f"    Scratch: {args.scratch}")
+        msg(f"    Seconds to merge: {secs}")
+        msg(f"    {int(numbytes)} bytes {format_byterate(byterate)}")
+
+def rank_files_merge_serial(args):
+    """Rank 0 merges data from all per-rank files into the final file."""
     if args.rank == 0:
-        print("Merging rank files ...", flush=True)
+        msg("Merging rank files ...", flush=True)
         merge_start = time.time()
         numbytes = 0
 
@@ -497,25 +515,25 @@
             filebase = get_filename(args, key)
             output_bin_files[key] = data_file_path(filebase)
             output_idx_files[key] = index_file_path(filebase)
+            best_dtype = best_fitting_dtype(args.vocab_size) if args.dataset_impl == "mmap" else None
             builders[key] = make_builder(output_bin_files[key], impl=args.dataset_impl,
-                                         dtype=best_fitting_dtype(args.vocab_size))
+                                         dtype=best_dtype)
 
         # merge all ranks into one file
         for rank in range(args.numranks):
             for key in args.columns:
                 infile = get_filename(args, key, rank)
-
-                print(f"Merging file {infile}", flush=True)
                 builders[key].merge_file_(infile)
 
                 # sum up the number of merged bytes
                 binfile = data_file_path(infile)
-                filesize = os.stat(binfile)[stat.ST_SIZE]
-                numbytes += filesize
+                idxfile = index_file_path(infile)
+                numbytes += os.stat(binfile)[stat.ST_SIZE]
+                numbytes += os.stat(idxfile)[stat.ST_SIZE]
 
         # finalize the merged file
-        print("Finalizing merged file ...", flush=True)
+        msg("Finalizing merged file ...", flush=True)
         for key in args.columns:
             builders[key].finalize(output_idx_files[key])
             del builders[key]  # file closed in __del__
@@ -523,18 +541,31 @@
         merge_end = time.time()
         secs = merge_end - merge_start
         byterate = numbytes / secs if secs > 0.0 else 0.0
-        print(f"Merged {args.numranks} files into {args.output_prefix}")
-        print("Merge stats:")
-        print(f"    Seconds to merge: {secs}")
-        print(f"    {numbytes} bytes {format_byterate(byterate)}")
+        msg(f"Merged {args.numranks} files into {args.output_prefix}")
+        msg("Serial merge stats:")
+        msg(f"    Seconds to merge: {secs}")
+        msg(f"    {numbytes} bytes {format_byterate(byterate)}")
 
     # hold everyone until rank 0 is done
-    barrier(args)
+    args.distctx.barrier()
+
+def rank_files_merge(args):
+    # use parallel merge if asked
+    if args.merge in ['parallel', 'both']:
+        rank_files_merge_parallel(args)
+
+    # if using node-local storage, skip sequential merge
+    if args.scratch is not None:
+        return
+
+    # can fall back to a serial merge
+    if args.merge in ['serial', 'both']:
+        rank_files_merge_serial(args)
 
 def rank_files_delete(args):
     # delete per-rank files
     if args.rank == 0:
-        print("Deleting rank files ...", flush=True)
+        msg("Deleting rank files ...", flush=True)
 
     for key in args.columns:
         filebase = get_filename(args, key, args.rank)
@@ -548,24 +579,17 @@
             os.remove(idxfile)
 
     # hold everyone until all are done
-    barrier(args)
+    args.distctx.barrier()
 
 def main():
     args = get_args()
 
     startup_start = time.time()
-    # connect processes and cache our rank and number of procs in args
-    init_distributed(args)
-
     # load the dataset
-    dset, err = load_dset(args)
-    if dset is None:
-        if err is not None:
-            raise err
-        return
+    dset = load_dset(args)
     if args.rank == 0:
         print(dset)
-        print("Selecting features:", args.columns)
+        msg(f"Processing features: {args.columns}")
 
     # create sample index list,
     # optionally shuffle the list,
@@ -579,26 +603,26 @@ def main():
     args.vocab_size = encoder.tokenizer.vocab_size
 
     # wait for all ranks before stopping timer
-    barrier(args)
+    args.distctx.barrier()
     startup_end = time.time()
     if args.rank == 0:
-        print("Seconds to startup:", startup_end - startup_start)
+        msg(f"Seconds to startup: {startup_end - startup_start}")
 
-    # have each rank write its file, returns False if any rank had a problem
-    success, err = rank_files_write(args, dset, idx, encoder)
-    if not success:
+    # have each rank write its file,
+    # all ranks should raise an exception if any rank has a problem
+    try:
+        rank_files_write(args, dset, idx, encoder)
+    except Exception as e:
+        # If any process fails, we skip the merge since the resulting file would be invalid.
+        # We still delete files to clean up, since those might be invalid anyway.
         if args.rank == 0:
-            # If any process fails, we skip the merge since the resulting file would be invalid.
-            # We still delete files to clean up, since those might be invalid anyway.
-            print(f"ERROR: At least one process failed to write its file, skipping merge and cleaning up", flush=True)
+            msgerr(f"At least one process failed to write its file, skipping merge and cleaning up", flush=True)
 
         # delete per-rank files, do this even on error
         rank_files_delete(args)
 
-        # raise exception caught during write phase
-        if err is not None:
-            raise err
-        return
+        # re-raise exception caught during write phase
+        raise e
 
     # all ranks were successful writing their file, merge them into one
     rank_files_merge(args)
@@ -606,5 +630,10 @@
     # delete per-rank files
     rank_files_delete(args)
 
+    end_time = time.time()
+    if args.rank == 0:
+        msg(f"Runtime: {end_time - startup_start} secs", flush=True)
+        msg(f"Done")
+
 if __name__ == '__main__':
     main()
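A possible invocation combining the new options (illustrative only, not part of the patch): the per-rank part files can be written to node-local storage and merged in parallel, for example:

    python -m torch.distributed.launch --nproc_per_node 40 --nnodes 8 \
        preprocess_data_dist.py \
        --input openwebtext \
        --output-prefix openwebtext-bert \
        --vocab bert-large-uncased-vocab.txt \
        --dataset-impl mmap \
        --tokenizer-type BertWordPieceLowerCase \
        --split-sentences \
        --merge parallel \
        --scratch /dev/shm

With --merge serial, rank 0 copies data from all per-rank files instead, which avoids requiring the file system to support parallel writes to a shared file at the cost of a slower merge; --scratch is only valid with the parallel merge.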