diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 000000000..7d8fa3054 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,66 @@ +name: do-the-job +on: workflow_dispatch +jobs: + start-runner: + name: Start self-hosted EC2 runner + runs-on: ubuntu-latest + outputs: + label: ${{ steps.start-ec2-runner.outputs.label }} + ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }} + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Start EC2 runner + id: start-ec2-runner + uses: machulav/ec2-github-runner@v2 + with: + mode: start + github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} + ec2-image-id: ami-0c7563df049759e8b + ec2-instance-type: p3.8xlarge + subnet-id: subnet-3502b45e + security-group-id: sg-e8f46d9d + aws-resource-tags: > # optional, requires additional permissions + [ + {"Key": "Name", "Value": "ec2-github-runner"}, + {"Key": "GitHubRepository", "Value": "${{ github.repository }}"} + ] + do-the-job: + name: Do the job on the runner + needs: start-runner # required to start the main job when the runner is ready + runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner + steps: + - name: Checkout + uses: actions/checkout@v2 + - name: Hello World + run: echo 'Hello World!' + - name: Test + run: pwd; ls -l + - name: GPU Test + run: nvidia-smi + + stop-runner: + name: Stop self-hosted EC2 runner + needs: + - start-runner # required to get output from the start-runner job + - do-the-job # required to wait when the main job is done + runs-on: ubuntu-latest + if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs + steps: + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + aws-region: us-east-2 + - name: Stop EC2 runner + uses: machulav/ec2-github-runner@v2 + with: + mode: stop + github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }} + label: ${{ needs.start-runner.outputs.label }} + ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }} diff --git a/.gitignore b/.gitignore index 18e7deab8..dfaa172e9 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ # tests # megatron autogenerated indices tests/data/*/*npy +tests/tools/openwebtext-1000.jsonl # macOS .DS_Store diff --git a/megatron/arguments.py b/megatron/arguments.py index 326c948ee..de16fd5b3 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -22,6 +22,7 @@ import deepspeed from megatron.enums import PositionEmbeddingType +import megatron def parse_args(extra_args_provider=None, defaults={}, @@ -313,6 +314,10 @@ def _add_network_size_args(parser): default=PositionEmbeddingType.absolute, help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.' ) + group.add_argument('--glu-activation', type=str, + choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(), + help='GLU activations to use.' 
+ ) return parser @@ -452,6 +457,8 @@ def _add_training_args(parser): help='Run optimizer on CPU') group.add_argument('--cpu_torch_adam', action='store_true', help='Use Torch Adam as optimizer on CPU.') + group.add_argument('--codecarbon-dir', type=str, default=None, + help='Write CodeCarbon logs to this directory.') return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 829fb1101..f7328dcbb 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -22,6 +22,7 @@ import torch +from megatron.global_vars import codecarbon_tracker_flush from megatron import (get_args, mpu, print_rank_0, @@ -135,7 +136,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): for i in range(len(model)): mpu.set_virtual_pipeline_model_parallel_rank(i) state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint() - + # Optimizer stuff. if not args.no_save_optim: if optimizer is not None: @@ -183,6 +184,11 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): if torch.distributed.is_initialized(): torch.distributed.barrier() + # since the code can be exited or aborted in various places we use the checkpoint saving as + # a save saving point for the codecarbon tracker. If the program doesn't run to its normal + # end, then only the data since the last saved checkpoint will be lost. + codecarbon_tracker_flush() + def _transpose_first_dim(t, num_splits, num_splits_first, model): input_shape = t.size() # We use a self_attention module but the values extracted aren't @@ -417,7 +423,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True def load_biencoder_checkpoint(model, only_query_model=False, only_context_model=False, custom_load_path=None): """ - selectively load retrieval models for indexing/retrieving + selectively load retrieval models for indexing/retrieving from saved checkpoints """ diff --git a/megatron/data/distdata.py b/megatron/data/distdata.py index b7549bed5..659d51bb9 100644 --- a/megatron/data/distdata.py +++ b/megatron/data/distdata.py @@ -1,6 +1,8 @@ +import os import numpy as np import torch +import torch.nn.functional as F import torch.distributed as dist class DistDataError(Exception): @@ -8,35 +10,36 @@ class DistDataError(Exception): pass class DistData(object): - def __init__(self, backend='gloo', use_mpi4py=False): - # use mpi4py instead of torch.distributed if requested - self.mpi4py = None - if use_mpi4py: - try: - from mpi4py import MPI - self.mpi4py = MPI - except: - #print(f"ERROR: mpi4py requested, but failed to import, falling back to torch.distributed.", flush=True) - pass + def __init__(self, backend='gloo'): + assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'" + + dist.init_process_group(backend, init_method="env://") # lookup our process rank and the group size - if self.mpi4py is not None: - self.comm = self.mpi4py.COMM_WORLD - self.rank = self.comm.Get_rank() - self.numranks = self.comm.Get_size() - else: - assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'" - dist.init_process_group(backend, init_method="env://") - self.rank = dist.get_rank() - self.numranks = dist.get_world_size() + self.rank = dist.get_rank() + self.numranks = dist.get_world_size() def allassert(self, cond, msg): - """Check that cond is True on all ranks, assert with msg everywhere if not.""" + """Check that cond is True on all ranks, assert with msg 
everywhere if not. + + To prevent deadlocks in cases where an assertion might only fail on one rank, + this executes an allreduce to ensure that if any rank finds that an assertion + has been violated, all ranks fail an assertion check. + The condition must be true on all ranks for this not to assert. + """ alltrue = self.alltrue(cond) assert alltrue, msg def allraise_if(self, err): - """Raise exception if err is not None on any rank.""" + """Raise exception if err is not None on any rank. + + Similarly to allassert, this raises an exception on all ranks if err + is set to an exception on any rank. Rank(s) where err is not None + re-raise err as exception, and ranks where err is None raise DistDataError. + Thus all ranks raise an exception if any rank has an active exception, + which helps avoid deadlocks in cases where an exception may be raised + on a subset of ranks. + """ alltrue = self.alltrue(err is None) if not alltrue: # At least one rank raised an exception. @@ -51,114 +54,62 @@ def allraise_if(self, err): def barrier(self): """Globally synchronize all processes""" - if self.mpi4py is not None: - self.comm.barrier() - else: - dist.barrier() + dist.barrier() def bcast(self, val, root): """Broadcast a scalar value from root to all ranks""" - if self.mpi4py is not None: - return self.comm.bcast(val, root=root) - else: - vals = [val] - dist.broadcast_object_list(vals, src=root) - return vals[0] - - def bcast_list(self, vals, root=0): - """Broadcast list of vals from root to all ranks, returns newly allocated list""" - if self.mpi4py is not None: - return self.comm.bcast(vals, root=root) - else: - # broadcast length of vals list - length = [len(vals)] - dist.broadcast_object_list(length, src=root) - - # allocate a tensor of appropriate size - # initialize tensor with list values on root - if self.rank == root: - tvals = torch.tensor(vals, dtype=torch.int64) - else: - tvals = torch.zeros(length[0], dtype=torch.int64) - - # broadcast tensor from root, and return as a new list - dist.broadcast(tvals, src=root) - return tvals.tolist() - - def scatterv_(self, invals, counts, outval, root=0): - """Scatter int64 values from invals according to counts array, receive values in outval""" + vals = [val] + dist.broadcast_object_list(vals, src=root) + return vals[0] + + def scatterv_(self, invals: np.array, counts: list, root:int=0): + """Scatter int64 values from invals according to counts array, return received portion in a new tensor""" self.allassert(len(counts) == self.numranks, f"Length of counts list {len(counts)} does not match number of ranks {self.numranks}") - self.allassert(outval.shape == (counts[self.rank],), - f"Rank {self.rank}: output buffer is of shape {outval.shape}, expected {(counts[self.rank],)}") - - self.allassert(outval.dtype == np.int64, - f"Requires outval to be of type numpy.int64") - - if self.mpi4py is not None: - counts = np.array(counts) - displs = np.cumsum(counts) - counts - self.comm.Scatterv([invals, counts, displs, self.mpi4py.INT64_T], outval, root=root) - else: - scatterlist = None - if self.rank == root: - scatterlist = list(torch.split(torch.from_numpy(invals), counts)) - outtensor = torch.from_numpy(outval) - dist.scatter(outtensor, scatterlist, src=root) + # Define list of tensors to scatter on the root. + # torch.distributed.scatter requires each tensor to be the same shape, + # so find the max size across all count values and pad. 
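# Illustrative example (not part of the patch), showing how the padded scatter
# behaves for a hypothetical run with 3 ranks and counts = [3, 1, 2]:
#
#   invals (root)  = [10, 11, 12, 20, 30, 31]
#   slices         = [[10, 11, 12], [20], [30, 31]]
#   scatterlist    = [[10, 11, 12], [20, 0, 0], [30, 31, 0]]   (padded to max_size=3)
#   rank 1 receives [20, 0, 0] and returns recvtensor[:counts[1]] == [20]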
+ max_size = max(counts) + scatterlist = None + if self.rank == root: + slices = list(torch.split(torch.from_numpy(invals), counts)) + scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices] + + # Receive a tensor of the max count size from the root, + # then copy values into output numpy array, which may be smaller. + recvtensor = torch.zeros(max_size, dtype=torch.int64) + dist.scatter(recvtensor, scatterlist, src=root) + return recvtensor[:counts[self.rank]] def alltrue(self, val): """Returns True if all procs input True, False otherwise""" - if self.mpi4py is not None: - inval = np.array([val], dtype=np.bool_) - outval = np.zeros_like(inval) - self.comm.Allreduce(inval, outval, op=self.mpi4py.LAND) - return bool(outval[0]) - else: - # torch.dist does not support reductions with bool types - # so we cast to int and cast the result back to bool - tensor = torch.tensor([int(val)], dtype=torch.int32) - dist.all_reduce(tensor, op=dist.ReduceOp.BAND) - return bool(tensor[0]) + # torch.dist does not support reductions with bool types + # so we cast to int and cast the result back to bool + tensor = torch.tensor([int(val)], dtype=torch.int32) + dist.all_reduce(tensor, op=dist.ReduceOp.BAND) + return bool(tensor[0]) def sum(self, val): """Compute sum of a scalar val, and return total on all ranks.""" - if self.mpi4py is not None: - insize = np.array([val], dtype=np.int64) - outsize = np.zeros_like(insize) - self.comm.Allreduce(insize, outsize, op=self.mpi4py.SUM) - return outsize[0] - else: - tensor = torch.tensor([val]) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - return tensor[0] - - def exscan(self, val): - """Compute prefix sum (exclusive scan) of scalar val, and return offset of each rank.""" - if self.mpi4py is not None: - insize = np.array([val], dtype=np.int64) - outsize = np.zeros_like(insize) - self.comm.Scan(insize, outsize, op=self.mpi4py.SUM) - return outsize[0] - insize[0] - else: - # torch.distributed doesn't have a scan, so fallback to allreduce - tensor = torch.zeros(self.numranks, dtype=torch.int64) - tensor[self.rank:] = val - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - return int(tensor[self.rank]) - val + tensor = torch.tensor([val]) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor[0] + + def exscan(self, val: int): + """Compute prefix sum (exclusive scan) of int64 val, and return offset of each rank.""" + # torch.distributed doesn't have a scan, so fallback to allreduce + tensor = torch.zeros(self.numranks, dtype=torch.int64) + tensor[self.rank:] = val + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return int(tensor[self.rank]) - val def min(self, val): """Return minimum of scalar val to all ranks.""" - if self.mpi4py is not None: - insize = np.array([val], dtype=np.int64) - outsize = np.zeros_like(insize) - self.comm.Allreduce(insize, outsize, op=self.mpi4py.MIN) - return outsize[0] - else: - tensor = torch.tensor([val]) - dist.all_reduce(tensor, op=dist.ReduceOp.MIN) - return tensor[0] + tensor = torch.tensor([val]) + dist.all_reduce(tensor, op=dist.ReduceOp.MIN) + return tensor[0] def minrank(self, cond): """Find first rank whose condition is True, return that rank if any, None otherwise.""" @@ -184,17 +135,13 @@ def bcast_first(self, val): val = self.bcast(val, root=minrank) return val - def all_sum_(self, vals): + def all_sum_(self, vals: np.array): """Sums values in numpy array vals element-wise and update vals in place with final result on all ranks""" - if self.mpi4py is not None: - outval = np.zeros_like(vals) - 
self.comm.Allreduce(vals, outval, op=self.mpi4py.SUM) - vals[:] = outval - else: - tensor = torch.from_numpy(vals) - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - - def open(self, filename): + # Builds torch.tensor with from_numpy to use same underlying memory as numpy array. + tensor = torch.from_numpy(vals) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + def open(self, filename, truncate=None): """Create, truncate, and open a file shared by all ranks.""" # Don't truncate existing file until all ranks reach this point @@ -207,6 +154,12 @@ def open(self, filename): if self.rank == 0: try: f = open(filename, 'wb') + + # Some file systems like GPFS deliver faster write speed + # if the file size is known before data is written to the file. + if truncate is not None: + f.truncate(truncate) + except Exception as e: err = e @@ -225,3 +178,43 @@ def open(self, filename): self.allraise_if(err) return f + + def remove(self, filename): + """Remove a shared file.""" + + # Don't remove the file until all are ready + self.barrier() + + # We'll capture any exception in this variable + err = None + + # Rank 0 removes the file if it exists. + if self.rank == 0: + try: + if os.path.exists(filename): + os.remove(filename) + except Exception as e: + err = e + + # Verify that rank 0 successfully removed the file. + self.allraise_if(err) + + def rename(self, srcfile, destfile): + """Rename a shared file.""" + + # Don't rename until all are ready + self.barrier() + + # We'll capture any exception in this variable + err = None + + # Rank 0 renames the file. + if self.rank == 0: + try: + if os.path.exists(srcfile): + os.rename(srcfile, destfile) + except Exception as e: + err = e + + # Verify that the rename succeeded + self.allraise_if(err) diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 1370af411..025e9e333 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -12,6 +12,7 @@ from functools import lru_cache import os +import stat import shutil import struct from itertools import accumulate @@ -351,6 +352,38 @@ def _warmup_mmap_file(path): pass +def exscan_from_cumsum_(arr): + # given an array holding the result of an inclusive scan (cumsum), + # convert to an exclusive scan (shift to the right) + # [10, 30, 35, 50] --> [0, 10, 30, 35] + if arr.size > 1: + arr[1:] = arr[:-1] + if arr.size > 0: + arr[0] = 0 + + +def get_pointers_with_total(sizes, elemsize, dtype): + """Return a numpy array of type np.dtype giving the byte offsets. + + Multiplies values in the sizes array by elemsize (bytes), + and then computes an exclusive scan to get byte offsets. + Returns the total number of bytes as second item in a tuple. 
+ """ + + # scale values in sizes array by elemsize to get sizes in bytes + pointers = np.array(sizes, dtype=dtype) + pointers *= elemsize + np.cumsum(pointers, axis=0, out=pointers) + + # get total number of bytes from all sizes (last element) + bytes_last = pointers[-1] if len(sizes) > 0 else 0 + + # convert to byte offsets + exscan_from_cumsum_(pointers) + + return pointers, bytes_last + + class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b'MMIDIDX\x00\x00' @@ -385,13 +418,7 @@ def _get_pointers(sizes, npdtype): """ # compute element sizes in bytes - bytesizes = np.array(sizes, dtype=npdtype) - bytesizes *= dtype().itemsize - - # exclusive scan to get byte offsets - pointers = np.cumsum(bytesizes, axis=0) - pointers -= bytesizes - + pointers, _ = get_pointers_with_total(sizes, dtype().itemsize, npdtype) return pointers def write(self, sizes, doc_idx): @@ -432,21 +459,21 @@ def __init__(self, path, skip_warmup=False): offset = stream.tell() if not skip_warmup: -# print_rank_0(" warming up index mmap file...") + print_rank_0(" warming up index mmap file...") _warmup_mmap_file(path) self._bin_buffer_mmap = np.memmap(path, mode='r', order='C') self._bin_buffer = memoryview(self._bin_buffer_mmap) -# print_rank_0(" reading sizes...") + print_rank_0(" reading sizes...") self._sizes = np.frombuffer( self._bin_buffer, dtype=np.int32, count=self._len, offset=offset) -# print_rank_0(" reading pointers...") + print_rank_0(" reading pointers...") self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes) -# print_rank_0(" reading document index...") + print_rank_0(" reading document index...") self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count, offset=offset + self._sizes.nbytes + self._pointers.nbytes) @@ -591,8 +618,8 @@ def merge_file_(self, another_file): index = MMapIndexedDataset.Index(index_file_path(another_file)) assert index.dtype == self._dtype -# total_len = len(index.sizes)+len(self._sizes) -# print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}") + total_len = len(index.sizes)+len(self._sizes) + print(f" concat {another_file} size={len(index.sizes)} for a total size of {total_len}") offset = len(self._sizes) self._sizes.extend(index.sizes) @@ -615,62 +642,73 @@ def finalize(self, index_file): # spot, and copy each file. def gather_files_dist_bin(outfile, filelist, distctx): """Concatenate binary files in filelist into a new file given by outfile""" - import stat - import shutil - # lookup size of each of our binary files filesizes = [os.stat(data_file_path(f))[stat.ST_SIZE] for f in filelist] - # compute offset this rank should start copying - # its data into the merged file + # compute total bytes of the merged file and the offset + # at which this rank will write data from its files numbytes = sum(filesizes) + count = distctx.sum(numbytes) offset = distctx.exscan(numbytes) - # Create shared output file. - with distctx.open(data_file_path(outfile)) as fout: - # Seek to appropriate starting offset in the merged file. - fout.seek(offset) + # We first write to a temporary file name. We rename to the final name + # if successful or delete the temporary file if not. + # This way if the final name appears, the user knows it's a valid file. 
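# A minimal sketch (not from this patch) of the temp-file protocol used below,
# expressed with the DistData helpers added in megatron/data/distdata.py.
# The function name and the write_fn callback are hypothetical.
def write_shared_file_atomically(distctx, finalname, write_fn):
    tmpname = finalname + ".tmp"
    distctx.remove(finalname)            # collective: rank 0 unlinks, everyone waits
    err = None
    try:
        with distctx.open(tmpname) as fout:
            write_fn(fout)               # each rank writes its own region of the file
    except Exception as e:
        err = e
    distctx.allraise_if(err)             # any rank's failure raises on every rank
    distctx.rename(tmpname, finalname)   # collective: publish the completed file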
+ finalname = data_file_path(outfile) + finalnametmp = finalname + ".tmp" + + # First delete the final file if it already exists + distctx.remove(finalname) + + # Catch I/O errors from any process + err = None + try: + # Create shared output file and pre-truncate to its final size. + with distctx.open(finalnametmp, truncate=count) as fout: + # Seek to appropriate starting offset in the merged file. + fout.seek(offset) - # Copy in contents of each of our files. - for f in filelist: - with open(data_file_path(f), "rb") as fsrc: - shutil.copyfileobj(fsrc, fout) + # Copy in contents of each of our files. + for f in filelist: + with open(data_file_path(f), "rb") as fsrc: + shutil.copyfileobj(fsrc, fout) - # TODO: check that all ranks wrote successfully - distctx.barrier() + except Exception as e: + err = e + # Check that all ranks wrote successfully. + # This will raise an exception all on ranks if we detect + # an exception on any rank. + distctx.allraise_if(err) -def write_list(fout, pos, vals, shift, offset, total, dtype): - """Write list of values to fout and return the number of bytes written assuming the total list size. + # Everyone wrote their part successfully. + # Rename the temporary file to the final file. + distctx.rename(finalnametmp, finalname) + + +def write_list_at_offset(fout, file_offset, vals, shift, elem_offset, dtype): + """Write list of vals to fout starting at an offset given by file_offset, elem_offset, and dtype. Copies list of values in vals to a numpy array of type dtype. - Adds a constant value to all elements as given in shift. + Adds a constant shift value to all elements. Writes the numpy array to the file handle at given offset and scaled by size of the datatype. - byteoffset = pos + vals * dtype().itemsize - Computes and return the total bytes written to write total elements of type dtype. + offset = file_offset + elem_offset * dtype().itemsize Parameters ---------- fout : file handle - Opened file handle to which to write vals - pos : int + Open file handle to which to write list of vals + file_offset : int Byte offset within the file where the global list starts - vals : list(int) + vals : list[int] List of values to be written shift : int Value to add to each element in vals before writing (use 0 for no change) - offset : int - Zero-based element index where vals starts within the global list - total : int - Total number of elements within the global list - dtype : numpy datatype + elem_offset : int + Zero-based element index where vals starts within the global list. + This value is scaled by dtype().itemsize to convert to a corresponding byte offset. + dtype : np.dtype numpy datatype to be used when writing the list to the file - - Returns - ------- - int - Number of bytes that would be required to write the global list - of length 'total' and of type 'dtype' """ # Make a copy of the vals list using the requested datatype. @@ -681,22 +719,21 @@ def write_list(fout, pos, vals, shift, offset, total, dtype): # Seek to proper offset for this rank and write # values into file, stored as given datatype. - fout.seek(pos + offset * dtype().itemsize) + fout.seek(file_offset + elem_offset * dtype().itemsize) fout.write(npvals.tobytes(order='C')) - # Return number of bytes written - return total * dtype().itemsize +def gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_code, distctx): + # Verify that no rank has found an inconsistent value in its own set of files. + # This includes an allreduce to verify that dtype_rank_consistent is True everywhere. 
+ distctx.allassert(dtype_rank_consistent, "Some rank found inconsistent dtype values") -def gather_files_dist_check_dtype(filelist, dtype_valid, dtype_code, distctx): - # verify that no rank has found an inconsistent value in its own set of files - distctx.allassert(dtype_valid, "Some rank found inconsistent dtype values") - - # verify that at least one rank found a dtype value + # Verify that at least one rank found a dtype value. + # Because of the bcast, the the value of first_dtype_code is the same on all ranks. first_dtype_code = distctx.bcast_first(dtype_code) assert first_dtype_code is not None, "Failed to find a dtype value in any index file" - # verify that the dtype is consistent on all ranks, if a rank has a dtype value + # Verify that the dtype is consistent on all ranks, if a rank has a dtype value. distctx.allassert(dtype_code == first_dtype_code or dtype_code is None, "Different dtype values detected in index files") # return the dtype @@ -708,9 +745,9 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): sizes = [] data_offsets = [0] dim_offsets = [0] - docs = [0] - dtype_valid = True # whether rank identifies inconsistent values in its files - dtype_value = None # the current dtype code, if any + doc_idx = [0] + dtype_rank_consistent = True # whether this rank identifies inconsistent dtype values in its files + dtype_value = None # the current dtype code to compare against, if any for f in filelist: # read index file for this file index = IndexedDataset(f) @@ -720,41 +757,41 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): sizes.extend(index.sizes) data_offsets.extend(index.data_offsets[1:] + data_offsets[-1]) dim_offsets.extend(index.dim_offsets[1:] + dim_offsets[-1]) - docs.extend(index.doc_idx[1:] + doc_offset) + doc_idx.extend(index.doc_idx[1:] + doc_offset) # check that the dtype in this index matches the dtype in our other files dtype_code = code(index.dtype) if dtype_value is None: dtype_value = dtype_code if dtype_value != dtype_code: - dtype_valid = False + dtype_rank_consistent = False - # Check that we have consistent dtypes in all files from all ranks - dtype = gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) + # Check that we have consistent dtypes in all files from all ranks, + # and return the dtype being used. + dtype = gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_value, distctx) - # Capture the last value in dim and data arrays before we delete any items. + # Capture the last value in the data array before we delete any items. # Note this may be zero on any rank that has no items, # but zero is the correct value in that case. - # We use this last value to compute a shift value that will - # later be added to each element in our lists. - dim_shift = distctx.exscan(dim_offsets[-1]) + # We use this last value to compute a shift value that + # is later be added to each element in our data list. data_shift = distctx.exscan(data_offsets[-1]) # Drop the zero entry from the lists that start with - # a "0" value unless we're rank 0 + # a "0" value unless we're rank 0. if distctx.rank != 0: del data_offsets[0] del dim_offsets[0] - del docs[0] + del doc_idx[0] # Compute total number of entires in data, size, dim, - # and docs lists across all ranks. Also compute the offset + # and doc_idx lists across all ranks. Also compute the offset # of the calling rank for each list considering the number # of entries for all ranks before the calling rank. 
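# Worked example (not from the patch): if 3 ranks hold 4, 2, and 3 list entries,
# distctx.sum(n) returns 9 on every rank, while distctx.exscan(n) returns
# 0, 4, and 6 respectively -- each rank's starting slot in the merged list.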
numdata = len(data_offsets) numsize = len(sizes) numdim = len(dim_offsets) - numdoc = len(docs) + numdoc = len(doc_idx) global_data_count = distctx.sum(numdata) global_size_count = distctx.sum(numsize) @@ -766,53 +803,84 @@ def gather_files_dist_idx_cached(outfile, filelist, distctx): global_dim_offset = distctx.exscan(numdim) global_doc_offset = distctx.exscan(numdoc) - # Create shared output file - with distctx.open(index_file_path(outfile)) as fout: - # Have rank 0 write the file header - # Broadcast number of bytes written from rank 0, - # and advance file position past file header on all ranks. - pos = 0 - if distctx.rank == 0: - pos = IndexedDatasetBuilder.write_header(fout, dtype, global_data_count, global_size_count, global_doc_count) - pos = distctx.bcast(pos, root=0) - - # TODO: is dim_shift == global_size_offset? - # The dimension list records the offset within - # the sizes list for each sentence. - # Adjust dimension index values for number of size values that - # come before the calling rank which is in dim_shift. - pos += write_list(fout, pos, dim_offsets, dim_shift, global_dim_offset, global_dim_count, np.int64) - - # The data index records the element offset to the start of each - # sentence within the binary data file, expressed in units of dtype().itemsize. - # Adjust data index values for number of elements that - # come before the calling rank, which is in data_shift. - pos += write_list(fout, pos, data_offsets, data_shift, global_data_offset, global_data_count, np.int64) - - # Each sentence is stored as a tensor. - # The tensor for each sentence can be multidimensional. - # The number of tensor dimensions per sentence is variable, - # and the size of each dimension of a sentence is arbitrary. - # The size list records a flattened list of the sizes - # for each dimension of a sentence. - pos += write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int64) - - # The document index points to the position in the sizes - # array for the first sentence of each document. - # Adjust document index for number of sentences that - # come before the calling rank which is in global_size_offset. - pos += write_list(fout, pos, docs, global_size_offset, global_doc_offset, global_doc_count, np.int64) - - # TODO: check that all ranks wrote successfully - distctx.barrier() + # We first write to a temporary file name. We rename to the final name + # if successful or delete the temporary file if not. + # This way if the final name appears, the user knows it's a valid file. + finalname = index_file_path(outfile) + finalnametmp = finalname + ".tmp" + + # First delete the final file if it already exists + distctx.remove(finalname) + + # Catch and I/O errors to later determine whether all ranks wrote successfully. + err = None + try: + # Create shared output file + with distctx.open(finalnametmp) as fout: + # Have rank 0 write the file header + file_offset = 0 + if distctx.rank == 0: + try: + file_offset = fout.tell() + file_offset += IndexedDatasetBuilder.write_header(fout, dtype, global_data_count, global_size_count, global_doc_count) + except Exception as e: + err = e + distctx.allraise_if(err) + + # Broadcast current file position from rank 0. + file_offset = distctx.bcast(file_offset, root=0) + + # The dimension list records the offset within + # the sizes list for each sentence. + # We shift our dimension index values to account for the number of size values + # that come before the calling rank which is in global_size_offset. 
+ write_list_at_offset(fout, file_offset, dim_offsets, global_size_offset, global_dim_offset, np.int64) + file_offset += global_dim_count * np.int64().itemsize + + # The data index records the element offset to the start of each + # sentence within the binary data file. Note that this is an + # element offset, not a byte offset. Each element is pyhsically stored + # in the data file as dtype().itemsize bytes. + # We shift the data index values according to the number of elements that + # come before the calling rank, which is stored in data_shift. + write_list_at_offset(fout, file_offset, data_offsets, data_shift, global_data_offset, np.int64) + file_offset += global_data_count * np.int64().itemsize + + # Each sentence is stored as a tensor. + # The tensor for each sentence can be multidimensional. + # The number of tensor dimensions per sentence is variable, + # and the size of each dimension of a sentence is arbitrary. + # The size list records a flattened list of the sizes + # for each dimension of a sentence. + # No shift value is needed. + write_list_at_offset(fout, file_offset, sizes, 0, global_size_offset, np.int64) + file_offset += global_size_count * np.int64().itemsize + + # The document index records the offset within the sizes + # array for the first sentence of each document. + # We shift the document index values for number of size values that + # come before the calling rank which is in global_size_offset. + write_list_at_offset(fout, file_offset, doc_idx, global_size_offset, global_doc_offset, np.int64) + file_offset += global_doc_count * np.int64().itemsize + + except Exception as e: + # if we encounter any exception while writing, store it for later + err = e + + # Check that all ranks wrote successfully + distctx.allraise_if(err) + + # Everyone wrote their part successfully. + # Rename the temporary file to the final file. + distctx.rename(finalnametmp, finalname) def gather_files_dist_idx_mmap(outfile, filelist, distctx): - # Read each index file and append items to the size and docs lists + # Read each index file and append items to the size and doc_idx lists sizes = [] - docs = [0] - dtype_valid = True # whether rank identifies inconsistent values in its files - dtype_value = None # the current dtype code, if any + doc_idx = [0] + dtype_rank_consistent = True # whether rank identifies inconsistent dtype values in its files + dtype_value = None # the current dtype code to compare against, if any for f in filelist: # read index file for this file index = MMapIndexedDataset.Index(index_file_path(f)) @@ -820,22 +888,23 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # append its size and doc entries to our lists docs_offset = len(sizes) sizes.extend(index.sizes) - docs.extend(index.doc_idx[1:] + docs_offset) + doc_idx.extend(index.doc_idx[1:] + docs_offset) # check that the dtype in this index matches the dtype in our other files dtype_code = code(index.dtype) if dtype_value is None: dtype_value = dtype_code if dtype_value != dtype_code: - dtype_valid = False + dtype_rank_consistent = False - # Check that we have consistent dtypes in all files from all ranks - dtype = gather_files_dist_check_dtype(filelist, dtype_valid, dtype_value, distctx) + # Check that we have consistent dtypes in all files from all ranks, + # and return the dtype being used. 
+ dtype = gather_files_dist_check_dtype(filelist, dtype_rank_consistent, dtype_value, distctx) # Drop the zero entry from the lists that start with # a "0" value unless we're rank 0 if distctx.rank != 0: - del docs[0] + del doc_idx[0] # Compute total number of size and document index # values across all ranks. Also compute the offset @@ -843,7 +912,7 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): # the values of sizes/docs for all ranks before the # calling rank. numsizes = len(sizes) - numdocs = len(docs) + numdocs = len(doc_idx) global_size_count = distctx.sum(numsizes) global_docs_count = distctx.sum(numdocs) @@ -851,65 +920,76 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): global_size_offset = distctx.exscan(numsizes) global_docs_offset = distctx.exscan(numdocs) - # Create shared output file - with distctx.open(index_file_path(outfile)) as fout: - # Have rank 0 write the file header - # Broadcast number of bytes written from rank 0, - # and advance file position past file header on all ranks. - pos = 0 - if distctx.rank == 0: - pos = MMapIndexedDataset.Index.write_header(fout, dtype, global_size_count, global_docs_count) - pos = distctx.bcast(pos, root=0) - - # The list of size values from each rank are - # concatenated and stored as int32. - pos += write_list(fout, pos, sizes, 0, global_size_offset, global_size_count, np.int32) - - # The pointer values store the byte offset to each sentence when in memory. - # A sentence has a variable number of tokens, given by - # its corresponding entry in the size array. Each token - # of a sentence is stored in units of type dtype, which consumes - # dtype().itemsize bytes (often a standard type that is just - # large enough to represent all elements of the vocabulary). - - # Compute byte sizes for each of our sentences given - # the token count and vocab dtype. - bytesizes = np.array(sizes, dtype=np.int64) - bytesizes *= dtype().itemsize - - # Inclusive scan to sum number of bytes over sentences. - pointers = np.cumsum(bytesizes, axis=0) - - # Account for bytes for all sentences on ranks - # before the calling rank. - bytes_last = pointers[-1] if len(sizes) > 0 else 0 - pointer_offset = distctx.exscan(bytes_last) - pointers += pointer_offset - - # Convert to exclusive scan to get global offset. - pointers -= bytesizes - - # Since the pointers array is the same length as the sizes array, - # we use global_size_offset and global_size_count to position - # within the file for writing the pointer values. - - # Seek to proper offset for this rank and write - # pointer values into file, stored as int64. - fout.seek(pos + global_size_offset * np.int64().itemsize) - fout.write(pointers.tobytes(order='C')) - - # Advance past list of pointer values - pos += global_size_count * np.int64().itemsize - - # The document index points to the position in the sizes - # array for the starting sentence of each document. - # A variable number of sentences can be in each document. - # Adjust document index for number of sentences that - # come before the calling rank which is in global_size_offset. - pos += write_list(fout, pos, docs, global_size_offset, global_docs_offset, global_docs_count, np.int64) - - # TODO: check that all ranks wrote successfully - distctx.barrier() + # Compute local byte offsets for each of our sentences given + # the token count and byte size of the vocab dtype. 
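# Worked example (not from the patch): with sizes = [10, 20, 5] and a 2-byte
# vocab dtype, get_pointers_with_total(sizes, 2, np.int64) returns
# pointers = [0, 20, 60] (exclusive byte offsets) and a total of 70 bytes.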
+ pointers, pointers_bytes = get_pointers_with_total(sizes, dtype().itemsize, np.int64) + + # Determine total number of bytes for all sentences on ranks + # before the calling rank. + pointer_offset = distctx.exscan(pointers_bytes) + + # We first write to a temporary file name. We rename to the final name + # if successful or delete the temporary file if not. + # This way if the final name appears, the user knows it's a valid file. + finalname = index_file_path(outfile) + finalnametmp = finalname + ".tmp" + + # First delete the final file if it already exists + distctx.remove(finalname) + + # Catch and I/O errors to later determine whether all ranks wrote successfully. + err = None + try: + # Create shared output file + with distctx.open(finalnametmp) as fout: + # Have rank 0 write the file header + file_offset = 0 + if distctx.rank == 0: + try: + file_offset = fout.tell() + file_offset += MMapIndexedDataset.Index.write_header(fout, dtype, global_size_count, global_docs_count) + except Exception as e: + err = e + distctx.allraise_if(err) + + # Broadcast current file position from rank 0. + file_offset = distctx.bcast(file_offset, root=0) + + # The list of size values from each rank are + # concatenated and stored as int32. + write_list_at_offset(fout, file_offset, sizes, 0, global_size_offset, np.int32) + file_offset += global_size_count * np.int32().itemsize + + # The pointer values store the byte offset to each sentence when in memory. + # A sentence has a variable number of tokens, given by + # its corresponding entry in the size array. Each token + # of a sentence is stored in units of type dtype, which consumes + # dtype().itemsize bytes (often a standard type that is just + # large enough to represent all elements of the vocabulary). + # Since the pointers array is the same length as the sizes array, + # we use global_size_offset and global_size_count to position + # within the file for writing the pointer values. + write_list_at_offset(fout, file_offset, pointers, pointer_offset, global_size_offset, np.int64) + file_offset += global_size_count * np.int64().itemsize + + # The document index points to the position in the sizes + # array for the starting sentence of each document. + # A variable number of sentences can be in each document. + # We shift the document index for number of sentences that + # come before the calling rank which is in global_size_offset. + write_list_at_offset(fout, file_offset, doc_idx, global_size_offset, global_docs_offset, np.int64) + file_offset += global_docs_count * np.int64().itemsize + + except Exception as e: + # if we encounter any exception while writing, store it for later + err = e + + # Check that all ranks wrote successfully + distctx.allraise_if(err) + + # Everyone wrote their part successfully. + # Rename the temporary file to the final file. + distctx.rename(finalnametmp, finalname) # Verify that all files in filelist are of the same index type. @@ -917,14 +997,10 @@ def gather_files_dist_idx_mmap(outfile, filelist, distctx): def gather_files_dist_check_impltype(filelist, distctx): # Sanity check for typos in file names. # Check that a data file exists for each of our files. - exists = True - for f in filelist: - binfile = data_file_path(f) - if not os.path.exists(binfile): - exists = False + all_files_exist = all([os.path.exists(data_file_path(f)) for f in filelist]) # Check that all ranks have all of their files. 
- distctx.allassert(exists, "Some rank is missing its input file") + distctx.allassert(all_files_exist, "Some rank is missing its input file") # map type string to an integer for easier bcast, use 0 for unknown implmap = {"cached": 1, "mmap": 2} @@ -958,6 +1034,9 @@ def gather_files_dist_check_impltype(filelist, distctx): if implmap[key] == bcasttype: return key + # raise exception if key for bcasttype was not found + raise UnreachableCode + # Collectively merge files into a new output file specified in filemain. # Each rank contributes a distinct list of zero or more files in filelist, @@ -973,7 +1052,7 @@ def gather_files_dist_check_impltype(filelist, distctx): def gather_files_dist(filemain, filelist, distctx): # Check that at least one input file is listed filecount = distctx.sum(len(filelist)) - assert filecount > 0, "No rank has any input files to merge" + assert filecount > 0, "All ranks have no input files to merge" # Check that files are all of the same index type indexstr = gather_files_dist_check_impltype(filelist, distctx) diff --git a/megatron/global_vars.py b/megatron/global_vars.py index 5f3c28463..1d94e8cf7 100644 --- a/megatron/global_vars.py +++ b/megatron/global_vars.py @@ -19,6 +19,8 @@ import sys import time +from pathlib import Path + import torch from megatron.tokenizer import build_tokenizer @@ -29,10 +31,10 @@ _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None _GLOBAL_TOKENIZER = None _GLOBAL_TENSORBOARD_WRITER = None +_GLOBAL_CODECARBON_TRACKER = None _GLOBAL_ADLR_AUTORESUME = None _GLOBAL_TIMERS = None - def get_args(): """Return arguments.""" _ensure_var_is_initialized(_GLOBAL_ARGS, 'args') @@ -63,6 +65,10 @@ def get_tensorboard_writer(): to check if it is initialized.""" return _GLOBAL_TENSORBOARD_WRITER +def get_codecarbon_tracker(): + """Return codecarbon tracker. It can be None so no need + to check if it is initialized.""" + return _GLOBAL_CODECARBON_TRACKER def get_adlr_autoresume(): """ADLR autoresume object. 
It can be None so no need @@ -86,6 +92,7 @@ def set_global_variables(extra_args_provider=None, args_defaults={}, if args.vocab_file or args.tokenizer_name_or_path: _ = _build_tokenizer(args) _set_tensorboard_writer(args) + _set_codecarbon_tracker(args) _set_adlr_autoresume(args) _set_timers() @@ -145,6 +152,56 @@ def _set_tensorboard_writer(args): 'no TensorBoard logs will be written.', flush=True) +def _set_codecarbon_tracker(args): + global _GLOBAL_CODECARBON_TRACKER + if not hasattr(args, 'codecarbon_dir') or args.codecarbon_dir is None: + return + + import codecarbon + if args.rank == 0: + print('> setting codecarbon ...') + + output_dir = args.codecarbon_dir + output_file = f"emissions-{args.rank:03d}.csv" + log_level = "warning" + country_iso_code="FRA" + + Path(output_dir).mkdir(parents=True, exist_ok=True) + _GLOBAL_CODECARBON_TRACKER = codecarbon.OfflineEmissionsTracker( + output_dir=output_dir, + output_file=output_file, + log_level=log_level, + country_iso_code=country_iso_code, + ) + + +def codecarbon_tracker_start(): + global _GLOBAL_CODECARBON_TRACKER + if _GLOBAL_CODECARBON_TRACKER is None: + return + + #print("CC START") + _GLOBAL_CODECARBON_TRACKER.start() + + +def codecarbon_tracker_stop(): + global _GLOBAL_CODECARBON_TRACKER + if _GLOBAL_CODECARBON_TRACKER is None: + return + + #print("CC STOP") + _GLOBAL_CODECARBON_TRACKER.stop() + + +def codecarbon_tracker_flush(): + global _GLOBAL_CODECARBON_TRACKER + if _GLOBAL_CODECARBON_TRACKER is None: + return + + #print("CC FLUSH") + _GLOBAL_CODECARBON_TRACKER.flush() + + def _set_adlr_autoresume(args): """Initialize ADLR autoresume.""" global _GLOBAL_ADLR_AUTORESUME diff --git a/megatron/model/activations.py b/megatron/model/glu_activations.py similarity index 84% rename from megatron/model/activations.py rename to megatron/model/glu_activations.py index 82ccdf098..9e0eb5b29 100644 --- a/megatron/model/activations.py +++ b/megatron/model/glu_activations.py @@ -7,10 +7,10 @@ class _GLUBaseModule(nn.Module): def __init__(self, activation_fn): super().__init__() self.activation_fn = activation_fn - + def forward(self, x): # dim=-1 breaks in jit for pt<1.10 - x1, x2 = x.chunk(2, dim=(x.ndim-1)) + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) return x1 * self.activation_fn(x2) @@ -38,3 +38,11 @@ def __init__(self): geglu = torch.jit.script(GEGLU()) reglu = torch.jit.script(ReGLU()) swiglu = torch.jit.script(SwiGLU()) + + +GLU_ACTIVATIONS = { + "geglu": geglu, + "liglu": liglu, + "reglu": reglu, + "swiglu": swiglu, +} diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index e7612b76f..473b8e06b 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -30,6 +30,7 @@ import deepspeed +from .glu_activations import GLU_ACTIVATIONS from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb # flags required to enable jit fusion kernels @@ -76,7 +77,9 @@ def __init__(self, init_method, output_layer_init_method): self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu - if args.openai_gelu: + if args.glu_activation: + self.activation_func = GLU_ACTIVATIONS[args.glu_activation] + elif args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: self.activation_func = erf_gelu diff --git a/megatron/testing_utils.py b/megatron/testing_utils.py index d402b59df..801d74d9b 100644 --- a/megatron/testing_utils.py +++ b/megatron/testing_utils.py @@ -25,6 +25,7 @@ import random from distutils.util import strtobool from io import StringIO 
+from packaging import version from pathlib import Path from typing import Iterator, Union from unittest import mock @@ -207,6 +208,35 @@ def get_gpu_count(): else: return 0 +def torch_assert_equal(actual, expected): + """ emulates the removed torch.testing.assert_equal """ + torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0) + + +def is_torch_bf16_available(): + # from https://github.com/huggingface/transformers/blob/26eb566e43148c80d0ea098c76c3d128c0281c16/src/transformers/file_utils.py#L301 + if is_torch_available(): + import torch + if not torch.cuda.is_available() or torch.version.cuda is None: + return False + if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8: + return False + if int(torch.version.cuda.split(".")[0]) < 11: + return False + if not version.parse(torch.__version__) >= version.parse("1.09"): + return False + return True + else: + return False + + +def require_torch_bf16(test_case): + """Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.9.""" + if not is_torch_bf16_available(): + return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.9")(test_case) + else: + return test_case + def get_tests_dir(append_path=None): """ diff --git a/megatron/training.py b/megatron/training.py index 21ef13b94..f66544dff 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -51,6 +51,7 @@ from megatron.schedules import forward_backward_pipelining_without_interleaving from megatron.schedules import forward_backward_pipelining_with_interleaving from megatron.utils import report_memory, flops_calculator +from megatron.global_vars import codecarbon_tracker_start, codecarbon_tracker_stop import deepspeed @@ -95,6 +96,8 @@ def pretrain(train_valid_test_dataset_provider, initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args_defaults) + codecarbon_tracker_start() + # Adjust the startup time so it reflects the largest value. # This will be closer to what scheduler will see (outside of # image ... launches. 
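# A minimal sketch (not from this patch) of the CodeCarbon lifecycle wired up by
# this change. All three helpers are no-ops when --codecarbon-dir is unset, so
# callers do not need to guard them; the loop below is purely illustrative.
from megatron.global_vars import (codecarbon_tracker_start,
                                  codecarbon_tracker_flush,
                                  codecarbon_tracker_stop)

codecarbon_tracker_start()        # once, right after initialize_megatron()
for iteration in range(3):        # hypothetical training loop
    pass                          # ... train step ...
    codecarbon_tracker_flush()    # called from save_checkpoint() at every save
codecarbon_tracker_stop()         # once, at the end of pretrain()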
@@ -162,6 +165,9 @@ def pretrain(train_valid_test_dataset_provider, test_data_iterator, model, 0, True) + codecarbon_tracker_stop() + + def update_train_iters(args): # For iteration-based training, we don't need to do anything diff --git a/requirements.txt b/requirements.txt index 234a9902a..a96ff42be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,5 @@ regex numpy transformers # git+https://github.com/microsoft/DeepSpeed.git@big-science +# edit to a higher SHA or future release if needed +git+git://github.com/mlco2/codecarbon.git@03479b695a771c28df6b877a809f5af3eb9ef3b8 diff --git a/tests/test_activations.py b/tests/test_activations.py index 85c949f4a..34097ad22 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -4,8 +4,8 @@ import torch from torch.nn import functional as F -from megatron.model.activations import liglu, geglu, reglu, swiglu -from megatron.testing_utils import set_seed +from megatron.model.glu_activations import GLU_ACTIVATIONS, geglu, liglu, reglu, swiglu +from megatron.testing_utils import set_seed, torch_assert_equal class TestActivations(unittest.TestCase): @@ -17,26 +17,34 @@ def setUp(self): self.num_channels = random.randint(1, 384) * 2 self.x = torch.randn(self.batch_size, self.seq_len, self.num_channels) self.x1, self.x2 = self.x.chunk(2, dim=-1) + # glu should halve the last dimension + self.output_shape = [self.batch_size, self.seq_len, self.num_channels // 2] def test_shapes(self): - # glu should halve the last dimension - output_shape = [self.batch_size, self.seq_len, self.num_channels // 2] - for activation_fn in [liglu, geglu, reglu, swiglu]: + for activation_fn in GLU_ACTIVATIONS.values(): output = activation_fn(self.x) - self.assertEqual(list(output.shape), output_shape) + self.assertEqual(list(output.shape), self.output_shape) def test_liglu(self): expected = self.x1 * self.x2 - torch.testing.assert_equal(liglu(self.x), expected) + torch_assert_equal(liglu(self.x), expected) def test_geglu(self): expected = self.x1 * F.gelu(self.x2) - torch.testing.assert_equal(geglu(self.x), expected) + torch_assert_equal(geglu(self.x), expected) def test_reglu(self): expected = self.x1 * F.relu(self.x2) - torch.testing.assert_equal(reglu(self.x), expected) + torch_assert_equal(reglu(self.x), expected) def test_swiglu(self): expected = self.x1 * F.silu(self.x2) - torch.testing.assert_equal(swiglu(self.x), expected) + torch_assert_equal(swiglu(self.x), expected) + + # from megatron.testing_utils import require_torch_bf16 + # @require_torch_bf16 + # def test_bf16_jit(self): + # x_bf16 = self.x.to(torch.bfloat16) + # for activation_fn in GLU_ACTIVATIONS.values(): + # output = activation_fn(x_bf16) + # self.assertEqual(list(output.shape), self.output_shape) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 3caf9f552..3723f97bc 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import filecmp import io import json import re @@ -77,3 +77,92 @@ def test_preprocess_data(self): for ext in ["bin", "idx"]: tgt_path = f"{output_prefix}_text_document.{ext}" self.assertTrue(Path(tgt_path).exists(), ) + + def compare_meg_data_files(self, tgt, ref): + for ext in ["bin", "idx"]: + tgt_path = f"{tgt}.{ext}" + ref_path = f"{ref}.{ext}" + self.assertTrue(Path(tgt_path).exists(), ) + self.assertTrue(filecmp.cmp(tgt_path, ref_path, shallow=False)) + + def test_process_data_microsoft(self): + """We want to be stable to Microsoft version.""" + src_dir = self.src_dir + data_dir = f"{self.data_dir}/gpt2" + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) + + input_path = f"{self.tests_dir}/tools/openwebtext-1000.jsonl" + + output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext" + + cmd = f""" + python {src_dir}/tools/preprocess_data.py + --input {input_path} + --output-prefix {output_prefix} + --dataset-impl mmap + --tokenizer-type GPT2BPETokenizer + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab {data_dir}/gpt2-tiny-vocab.json + --append-eod + --workers 2 + """.split() + + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + execute_subprocess_async(cmd, env=self.get_env()) + + self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document") + + def test_process_data_dist_microsoft(self): + """We want to be stable to Microsoft version.""" + src_dir = self.src_dir + data_dir = f"{self.data_dir}/gpt2" + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) + + output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext_1k" + + cmd = f""" + python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py + --input openwebtext-10k + --count 1000 + --output-prefix {output_prefix} + --dataset-impl mmap + --tokenizer-type GPT2BPETokenizer + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab {data_dir}/gpt2-tiny-vocab.json + --append-eod + """.split() + + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + execute_subprocess_async(cmd, env=self.get_env()) + + self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document") + + def test_process_data_dist_serial_microsoft(self): + """We want to be stable to Microsoft version.""" + src_dir = self.src_dir + data_dir = f"{self.data_dir}/gpt2" + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) + + output_prefix = f"{output_dir}/test-ds-meg-gpt2-openwebtext_1k" + + cmd = f""" + python -m torch.distributed.launch --nproc_per_node 2 {src_dir}/tools/preprocess_data_dist.py + --input openwebtext-10k + --count 1000 + --merge serial + --output-prefix {output_prefix} + --dataset-impl mmap + --tokenizer-type GPT2BPETokenizer + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab {data_dir}/gpt2-tiny-vocab.json + --append-eod + """.split() + + # keep for quick debug + # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die + execute_subprocess_async(cmd, env=self.get_env()) + + self.compare_meg_data_files(f"{output_prefix}_text_document", f"{data_dir}/meg-gpt2-openwebtext_text_document") + diff --git a/tests/test_training.py b/tests/test_training.py index 7306615f1..85d1a537b 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -90,6 +90,7 @@ def test_training_all(self): --eval-interval 10 --eval-iters 5 --checkpoint-activations + --glu-activation geglu 
--exit-interval {exit_interval} --merge-file {data_dir}/gpt2-tiny-merges.txt @@ -97,6 +98,7 @@ def test_training_all(self): --save {output_dir}/checkpoints --load {output_dir}/checkpoints --data-path {data_dir}/meg-gpt2-openwebtext_text_document + --codecarbon-dir {output_dir}/codecarbon --tensorboard-dir {output_dir}/tensorboard --tensorboard-queue-size 5 --log-timers-to-tensorboard diff --git a/tools/indexed_json.py b/tools/indexed_json.py new file mode 100644 index 000000000..9ebf12aed --- /dev/null +++ b/tools/indexed_json.py @@ -0,0 +1,354 @@ +import os +import stat +import json +import struct +import time +import numpy as np + +class IndexedJSON(object): + def __init__(self, filename, distctx, bufsize=16*1024*1024, progress=10.0): + self.filename = filename # JSON file name + self.distctx = distctx # distributed environment for collective ops + self.bufsize = bufsize # buffer size used while building index + self.numsamples = 0 # number of records in JSON file + self.fh_idx = None # file handle to JSON index file + self.fh_json = None # file handle to JSON file + self.time_index = 0 # record cost to construct index + self.progress = progress # number of secs between progress msgs (0 to disable) + + # given a JSON file name, compute the name of its index file + self.filename_idx = self.index_filename(self.filename) + + # determine whether we need to create the index + create_index = False + exists = self.test_exists(self.filename_idx) + if not exists: + # index file does not exist + create_index = True + else: + # index exists, but rebuild the index if the original file + # has been modified since the index was built + mtime = self.get_mtime(self.filename) + mtime_idx = self.get_mtime(self.filename_idx) + if mtime > mtime_idx: + create_index = True + if create_index: + self.create_index(self.filename, self.bufsize) + + # Open the index and the json files for reading. + # Disable buffering to avoid reading extra bytes we won't use. + self.fh_idx = open(self.filename_idx, "rb", buffering=0) + self.fh_json = open(self.filename, "rb", buffering=0) +# self.fh_idx = open(self.filename_idx, "rb") +# self.fh_json = open(self.filename, "rb") + + self.read_index_header() + + # Identify number of samples in JSON file. + # For now, we can do that using the size of index file. 
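# Index file layout assumed by the version-1 check below: an 8-byte magic value
# b'INDXJSON', an 8-byte big-endian version number, then one (offset, length)
# pair of int64 per record (16 bytes each), so
#   numsamples = (index_file_size - 16) // 16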
+        self.numsamples = None
+        if self.idx_version == 1:
+            # version 1 has a 16-byte header
+            # followed by a list of (offset, length) pairs of uint64
+            header_size = 16
+            filesize_idx = self.get_filesize(self.filename_idx)
+            self.numsamples = int((filesize_idx - header_size) / 16)
+
+    def test_exists(self, filename):
+        """Test whether file exists and broadcast result to all ranks."""
+        exists = False
+        if self.distctx.rank == 0:
+            exists = os.path.exists(filename)
+        exists = self.distctx.bcast(exists, root=0)
+        return exists
+
+    def get_filesize(self, filename):
+        """Lookup filesize and broadcast to all ranks."""
+        filesize = 0
+        if self.distctx.rank == 0:
+            filesize = os.stat(filename)[stat.ST_SIZE]
+        filesize = self.distctx.bcast(filesize, root=0)
+        return filesize
+
+    def get_mtime(self, filename):
+        """Lookup file mtime and broadcast to all ranks."""
+        mtime = 0
+        if self.distctx.rank == 0:
+            mtime = os.stat(filename)[stat.ST_MTIME]
+        mtime = self.distctx.bcast(mtime, root=0)
+        return mtime
+
+    def read_index_header(self):
+        """Read header from index file and check its version."""
+        # Rank 0 reads the header, and bcasts its version
+        version = None
+        if self.distctx.rank == 0:
+            try:
+                # Seek to the front of the file
+                self.fh_idx.seek(0)
+
+                # Read the magic value and check that it matches what we expect
+                magic = self.fh_idx.read(8)
+                if magic == b'INDXJSON':
+                    # Good magic value, now read file format version number
+                    buf = self.fh_idx.read(8)
+                    if len(buf) == 8:
+                        version = struct.unpack(">Q", buf)[0]
+            except Exception as e:
+                pass
+
+        # Get version from rank 0 (should be None on any error)
+        self.idx_version = self.distctx.bcast(version, root=0)
+
+        # Check that we have a version number that we support
+        if self.idx_version != 1:
+            raise ValueError(f"Unknown index file format version '{self.idx_version}' in file '{self.filename_idx}'")
+
+    def index_filename(self, filename):
+        """Given the name of a JSON file, return the name of its index file."""
+        return filename + '.idx'
+
+    def get_start_end(self, num):
+        """Given num items, compute and return [start,end) range on each rank."""
+        rank = self.distctx.rank
+        num_ranks = self.distctx.numranks
+
+        num_per_rank = num // num_ranks
+        remainder = num % num_ranks
+        if rank < remainder:
+            start = (num_per_rank + 1) * rank
+            end = start + (num_per_rank + 1)
+        else:
+            start = (num_per_rank + 1) * remainder + num_per_rank * (rank - remainder)
+            end = start + num_per_rank
+        return start, end
+
+    def create_index(self, filename, bufsize):
+        """Given a JSON file named dataset.jsonl, write index to dataset.jsonl.idx."""
+
+        # To compute this index, ranks collectively scan the JSON
+        # file and record the byte offset of newline characters.
+        # Those byte offsets are accumulated in a temporary index file
+        # until the entire JSON file has been scanned.  The processes
+        # then read back those byte locations from the temporary file
+        # to compute the length of each record.  Finally for each
+        # record an (offset,length) pair of int64 types is written into
+        # the index file to specify the starting offset and length of
+        # each record in the JSON file.
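For reference, the same on-disk layout described in the comment above can be produced by a single process. The sketch below is illustrative only and is not part of this patch (the helper names build_index_serial and read_record are made up); it assumes the format used here: an 8-byte b'INDXJSON' magic, an 8-byte big-endian version, then one big-endian (offset, length) int64 pair per line of the .jsonl file that ends with a newline.

    import json
    import struct

    import numpy as np

    def build_index_serial(jsonl_path):
        """Single-process sketch of the index layout: magic, version,
        then one big-endian (offset, length) int64 pair per record."""
        offsets = [0]
        with open(jsonl_path, "rb") as f:
            for line in f:
                offsets.append(offsets[-1] + len(line))
        pairs = np.zeros((len(offsets) - 1, 2), dtype=">i8")
        pairs[:, 0] = offsets[:-1]                  # start of each record
        pairs[:, 1] = np.diff(np.array(offsets))    # length of each record (incl. newline)
        with open(jsonl_path + ".idx", "wb") as fidx:
            fidx.write(b"INDXJSON")
            fidx.write(struct.pack(">Q", 1))
            fidx.write(pairs.tobytes(order="C"))

    def read_record(jsonl_path, idx):
        """Fetch record `idx` using the index written above."""
        with open(jsonl_path + ".idx", "rb") as fidx:
            fidx.seek(16 + idx * 16)
            offset, size = np.frombuffer(fidx.read(16), dtype=">i8")
        with open(jsonl_path, "rb") as f:
            f.seek(offset)
            return json.loads(f.read(size))

The collective create_index below should write the same layout; it differs only in how the newline offsets are gathered across ranks.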
+
+        time_start = time.time()
+        rank = self.distctx.rank
+        numranks = self.distctx.numranks
+
+        # define file names for the index and the temporary index file
+        filename_idx = self.index_filename(filename)
+        filename_tmp = filename_idx + 'tmp'
+
+        # lookup and broadcast size of JSON file to all ranks
+        filesize = self.get_filesize(filename)
+
+        # if progress messages are enabled, print a header about what we're doing
+        if rank == 0 and self.progress > 0.0:
+            print(f"Indexing '{filename}' of {filesize} bytes ...", flush=True)
+
+        # create the temporary index file, shared across all ranks
+        with self.distctx.open(filename_tmp) as ftmp:
+            # open and scan the JSON file
+            time_next = time_start + self.progress
+            recstart = 0
+            with open(filename, "rb") as f:
+                curpos = 0
+                while curpos < filesize:
+                    # each rank reads a section of the file
+                    offset = curpos + bufsize * rank
+                    f.seek(offset)
+                    data = f.read(bufsize)
+
+                    # scan section for newline chars, and record offset
+                    # of byte immediately following each newline (start of new record)
+                    newlines = []
+                    pos = 0
+                    length = len(data)
+                    while pos < length:
+                        found = data.find(b'\n', pos)
+                        if found >= 0:
+                            # We actually store the byte offset to the start
+                            # of the record that would follow the newline char.
+                            newlines.append(offset + found + 1)
+
+                            # Update our buffer position and keep scanning.
+                            pos = found + 1
+                        else:
+                            # No newlines in the remainder of the buffer
+                            break
+
+                    # Count number of newline chars we found,
+                    # and compute sum and offset of newlines across ranks.
+                    numrecs = len(newlines)
+                    reccount = self.distctx.sum(numrecs)
+                    recoffset = self.distctx.exscan(numrecs)
+
+                    # Store offsets as int64
+                    vals = np.array(newlines, dtype=np.int64)
+
+                    # Write offsets into temporary index file
+                    pos = (recstart + recoffset) * 8
+                    ftmp.seek(pos)
+                    ftmp.write(vals.tobytes(order='C'))
+
+                    # Bump up to next slot in the temporary index file.
+                    recstart += reccount
+
+                    # Move on to the next section of the JSON file.
+                    curpos += bufsize * numranks
+
+                    # this can take a while, so print progress messages if enabled
+                    if rank == 0 and self.progress > 0.0:
+                        time_now = time.time()
+                        if time_now > time_next:
+                            time_next = time_now + self.progress
+                            elapsed = time_now - time_start
+                            percent = curpos * 100.0 / filesize if filesize > 0 else 0.0
+                            byterate = curpos / elapsed / (1024.0 * 1024.0) if elapsed > 0.0 else 0.0
+                            remaining = (100.0 - percent) * elapsed / percent if percent > 0.0 else 0.0
+                            print(f"Scanned {curpos} of {filesize} bytes ({percent:0.2f}%) in {int(elapsed)} secs, "
+                                  f"{byterate:0.3f} MB/s, {int(remaining)} secs left ...", flush=True)
+
+        # Wait for all ranks to close the file.
+        self.distctx.barrier()
+
+        # Create the actual index file.
+        with self.distctx.open(filename_idx) as fidx:
+            # Rank 0 writes the index file header.
+            if self.distctx.rank == 0:
+                fidx.write(b'INDXJSON')          # use 8-byte magic value of "INDXJSON"
+                fidx.write(struct.pack(">Q", 1)) # file format version number in network byte order
+            data_offset = 16
+
+            # We'll read the offsets back from the temporary index file.
+            with open(filename_tmp, "rb") as ftmp:
+                # Compute the [start,end) range for this rank within the list of offsets.
+                start, end = self.get_start_end(recstart)
+
+                # Determine how many records this rank is responsible for.
+                readcount = end - start
+                if readcount > 0:
+                    # We'll read all offsets in our portion,
+                    # plus one offset that comes immediately before our section.
+                    readcount += 1
+                    if start > 0:
+                        pos = (start - 1) * 8
+                        ftmp.seek(pos)
+
+                    # Allocate a buffer and read in the offsets
+                    recoffsets = np.zeros(readcount, dtype=np.int64)
+                    if start > 0:
+                        ftmp.readinto(recoffsets)
+                    else:
+                        # We leave the first entry as 0 on rank 0
+                        ftmp.readinto(recoffsets[1:])
+
+                    # Compute length of each record by computing the difference
+                    # between consecutive offset values.  Also ignore the first
+                    # offset when writing our (offset,length) pairs.
+                    lengths = recoffsets[1:] - recoffsets[:-1]
+                    offsets = recoffsets[:-1]
+
+                    # Prepare list of (offset,length) pairs for writing.
+                    # Stored as int64 types.
+                    vals = np.zeros((readcount - 1, 2), dtype=np.int64)
+                    vals[:,0] = offsets
+                    vals[:,1] = lengths
+
+                    # Write our portion of the index values.
+                    # We write values to the index file in network byte order,
+                    # so that the file can be read correctly on any system.
+                    if readcount > 0:
+                        pos = data_offset + start * 16
+                        fidx.seek(pos)
+                        fidx.write(vals.astype(">i8").tobytes(order='C'))
+
+        # Wait for all ranks to finish writing to the index file.
+        self.distctx.barrier()
+
+        # Can now delete the temporary index file.
+        if rank == 0:
+            os.remove(filename_tmp)
+
+        # Wait for everyone again and record how long it took.
+        self.distctx.barrier()
+        time_end = time.time()
+        self.time_index = time_end - time_start
+
+        # if progress messages are enabled, print a summary
+        if rank == 0 and self.progress > 0.0:
+            print(f"Indexed '{filename}' in {int(self.time_index)} seconds", flush=True)
+
+    def __str__(self):
+        return (f"IndexedJSON (\n"
+                f"  file: {self.filename}\n"
+                f"  rows: {self.numsamples}\n"
+                f")")
+
+    def __len__(self):
+        """Return number of samples (lines) in the JSON file."""
+        return self.numsamples
+
+    def __getitem__(self, idx):
+        """Given a sample id, return the sample as a python object parsed from JSON string."""
+        # read offset and length of record from the index
+        # seek to offset in JSON file and read the record
+        buf = self.read(idx)
+
+        # convert json record into a python dictionary
+        try:
+            #entry = json.loads(buf.decode("utf-8").strip())
+            entry = json.loads(buf)
+            return entry
+        except:
+            # TODO: throw exception instead?
+            return None
+
+    def __get__(self, idx):
+        return self.__getitem__(idx)
+
+    def index(self, idx):
+        """Given a sample id, return (offset, size) tuple of location of sample in the JSON file."""
+        assert idx < self.numsamples
+
+        if self.idx_version == 1:
+            # Version 1 has a 16-byte header followed by a
+            # list of (offset, length) pairs of uint64
+
+            # Seek to the right spot in the index file for the given sample id.
+            header_size = 16
+            offset_idx = header_size + idx * 16
+            self.fh_idx.seek(offset_idx)
+
+            # Read offset and length of record from the index.
+            # Values in the index file are stored in network byte order.
+            vals = np.zeros(2, dtype=">i8")
+            self.fh_idx.readinto(vals)
+            offset = vals[0]
+            size = vals[1]
+
+            return offset, size
+
+    def pread(self, offset, size):
+        """Read size bytes at the given offset in the JSON file and return as a buffer."""
+        # seek to offset in JSON file and read the record
+        self.fh_json.seek(offset)
+        buf = self.fh_json.read(size)
+        return buf
+
+    def read(self, idx):
+        """Given a sample id, read sample from the file and return as a buffer."""
+        # read offset and length of record from the index
+        # seek to offset in JSON file and read the record
+        offset, size = self.index(idx)
+        return self.pread(offset, size)
+
+    def size(self, idx):
+        """Given a sample id, return the number of bytes of that sample as stored in the JSON file."""
+        offset, size = self.index(idx)
+        return size
diff --git a/tools/preprocess_data_dist.py b/tools/preprocess_data_dist.py
index bd12b70fc..afb349a6f 100644
--- a/tools/preprocess_data_dist.py
+++ b/tools/preprocess_data_dist.py
@@ -20,7 +20,7 @@
 from datasets import load_dataset
 dset = load_dataset('openwebtext', split='train')
 
-The implementation can use `mpi4py` or `torch.distributed` for node communication,
+The implementation uses `torch.distributed` for inter-process communication,
 and it assumes that files are written to a global file system, such that one
 process can read a file written by another process.
 
@@ -34,17 +34,14 @@
 
 To run:
 
-mpiexec -np 320 python preprocess_data_dist.py \
-    --input openwebtext \
-    --count 1_000_000 \
-    --scratch /dev/shm \
-    --shuffle \
-    --seed 100 \
-    --output-prefix openwebtext-bert \
-    --vocab bert-large-uncased-vocab.txt \
-    --dataset-impl mmap \
-    --tokenizer-type BertWordPieceLowerCase \
-    --split-sentences
+python -m torch.distributed.launch --nproc_per_node 40 --nnodes 8 \
+    preprocess_data_dist.py \
+    --input openwebtext \
+    --output-prefix openwebtext-bert \
+    --vocab bert-large-uncased-vocab.txt \
+    --dataset-impl mmap \
+    --tokenizer-type BertWordPieceLowerCase \
+    --split-sentences
 """
 
 import argparse
@@ -68,6 +65,8 @@
 from datasets import config, logging, load_dataset
 from datasets.utils.file_utils import OfflineModeIsEnabled
 
+from indexed_json import IndexedJSON
+
 from megatron.tokenizer import build_tokenizer
 from megatron.data.indexed_dataset import data_file_path, index_file_path, make_builder, best_fitting_dtype, gather_files_dist
 from megatron.data.distdata import DistData
@@ -176,14 +175,12 @@ def get_args():
                        help='Select torch.distributed backend.')
     group.add_argument('--local_rank', type=int, default=None,
                        help='Local rank of calling process on its node (from torch.distributed.launch).')
-    group.add_argument('--mpi4py', action='store_true',
-                       help='Assume script has been launched as an MPI job, and use mpi4py for communication.')
     group.add_argument('--merge', type=str, default='parallel', choices=['parallel', 'serial', 'both'],
                        help=('Method to merge intermediate per-rank files into the final data files. '
                              'With "parallel", each rank writes directly to the final files, '
                              'while rank 0 copies data from all per-rank files with "serial". '
                              'A parallel merge can be faster, but for correctness, it requires the underlying file system '
-                             'to support parallel write operations to a file shared among multiple processes. '
+                             'to support parallel write operations to a file that is shared among multiple processes. '
                              'One can choose "both" for testing purposes, in which case the final files written '
                              'by the parallel method are given an additional ".par" extension.'))
     group.add_argument('--scratch', type=str, default=None,
@@ -195,20 +192,18 @@ def get_args():
     args = parser.parse_args()
     args.keep_empty = False
 
-    # some default/dummy values for the tokenizer
-    args.rank = 0
-    args.make_vocab_size_divisible_by = 128
-    args.tensor_model_parallel_size = 1
-    args.vocab_extra_ids = 0
-
     # initialize our distributed environment
-    # use mpi4py instead of torch.distributed if requested
-    args.distctx = DistData(use_mpi4py=args.mpi4py, backend=args.torch_backend)
+    args.distctx = DistData(backend=args.torch_backend)
 
     # some functions like build_tokenizer use args.rank to filter stdout messages
     args.rank = args.distctx.rank
     args.numranks = args.distctx.numranks
 
+    # some default/dummy values for the tokenizer
+    args.make_vocab_size_divisible_by = 128
+    args.tensor_model_parallel_size = 1
+    args.vocab_extra_ids = 0
+
     if args.tokenizer_type.lower().startswith('bert'):
         if not args.split_sentences:
             if args.rank == 0:
@@ -221,7 +216,7 @@ def get_args():
     # TODO: perhaps more user friendly to disable scratch and print a warning?
     # check that serial merge is not attempted with scratch
     if args.scratch is not None and args.merge != 'parallel':
-        assert False, "The --scratch option is only valid with --merge=parallel"
+        raise ValueError("The --scratch option is only valid with --merge=parallel")
 
     return args
 
@@ -252,7 +247,6 @@ def load_dset(args):
 
     # Load the specified HuggingFace dataset.
    # Give rank 0 a head start in case the dataset is not already cached.
-    success = True
     err = None
     dsetname = args.input
     if args.rank == 0:
@@ -265,17 +259,13 @@
             msgerr(f" from datasets import load_dataset")
             msgerr(f" dset = load_dataset('{dsetname}')")
             msgerr(f"Alternatively, one can force this script to download by setting $HF_DATASETS_OFFLINE=0", flush=True)
-            success = False
             err = e
         except Exception as e:
-            msgerr("Unexpected error:", sys.exc_info()[0], flush=True)
-            success = False
+            msgerr(f"Unexpected error: {sys.exc_info()[0]}", flush=True)
             err = e
 
     # determine whether rank 0 succeeded in loading the dataset
-    success = args.distctx.alltrue(success)
-    if not success:
-        return None, err
+    args.distctx.allraise_if(err)
 
     # Rank 0 succeeded, attempt to load dataset on all other ranks.
     # This should load from cache now.
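The success-flag plumbing in load_dset is replaced by DistData.allraise_if(err), which is not shown in this diff (presumably it lives alongside the other DistData changes in megatron/data/distdata.py). A minimal sketch of what such a collective check could look like on top of torch.distributed, assuming the 'gloo' process group set up by DistData is already initialized:

    import torch
    import torch.distributed as dist

    def allraise_if(err):
        """Sketch of a collective error check: every rank passes its local
        exception (or None); if any rank saw one, all ranks raise."""
        flag = torch.tensor([1 if err is not None else 0], dtype=torch.int64)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)   # count of failed ranks
        if flag.item() > 0:
            if err is not None:
                raise err                             # re-raise the local error
            raise RuntimeError("An exception occurred on another rank")

The key property is that the call is collective: every rank participates in the all_reduce, so ranks that succeeded still learn that a peer failed and can abort cleanly instead of hanging at the next barrier.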
@@ -284,22 +274,17 @@
         dset = load_dataset(dsetname, split=args.split, keep_in_memory=None)
     except Exception as e:
         # this print might be noisy, but better than nothing
-        msgerr("Unexpected error:", sys.exc_info()[0], flush=True)
-        success = False
+        msgerr(f"Unexpected error: {sys.exc_info()[0]}", flush=True)
         err = e
 
     # verify that all ranks loaded the dataset
-    success = args.distctx.alltrue(success)
-    if not success:
-        if args.rank == 0:
-            msgerr(f"At least one process failed to load {dsetname}", flush=True)
-        return None, err
+    args.distctx.allraise_if(err)
 
     time_end = time.time()
     if args.rank == 0:
         msg(f"Seconds to load dataset: {time_end - time_start}", flush=True)
 
-    return dset, err
+    return dset
 
 def get_num_samples(args, dset_size):
     """Given a dataset size and optional count argument, return number of samples to process."""
@@ -335,13 +320,10 @@ def select_sample_list(args, dset_size):
     # get a list of the number of elements each rank will hold
     counts = get_proc_counts(num_samples, args.numranks)
 
-    # allocate space to hold its portion of the list
-    idx = np.zeros(counts[args.rank], np.int64)
-
     # scatter sample index values from rank 0 to all procs
     # based on distribution defined in counts list
     time_bcast = time.time()
-    args.distctx.scatterv_(idxlist, counts, idx, root=0)
+    idx = args.distctx.scatterv_(idxlist, counts, root=0)
     args.distctx.barrier()
     time_end = time.time()
 
@@ -384,7 +366,6 @@ def rank_files_write(args, dset, idx, encoder):
     dset_stats = np.zeros(3, dtype=np.int64) # docs, sentences, bytes
 
     # we'll set this to false on any problem
-    success = True
     err = None
     times = np.zeros(3, dtype=np.float32) # read, tokenize, write
     try:
@@ -451,7 +432,6 @@
             del builders[key] # file closed in __del__
     except Exception as e:
         # caught an exception, assume our file is invalid
-        success = False
         err = e
 
     # In case rank 0 finishes early and stops printing progress messages,
@@ -483,32 +463,32 @@
         msg(f" Total encode seconds {times[1]}, {secs_encode_per_sample} sec/sample")
         msg(f" Total write seconds {times[2]}, {secs_write_per_sample} sec/sample")
 
-    # allreduce to check whether all ranks wrote their part successfully
-    success = args.distctx.alltrue(success)
-    return success, err
+    # check whether all ranks wrote their part successfully
+    args.distctx.allraise_if(err)
 
 def rank_files_merge_parallel(args):
     """Each process directly writes its portion of the data from its per-rank file into the final file."""
     merge_start = time.time()
     numbytes = np.zeros(1, dtype=np.int64)
     for key in args.columns:
+        # merge the per-rank file from each process into a single shared file
         filemain = get_filename(args, key)
         filerank = get_filename(args, key, args.rank)
         gather_files_dist(filemain, [filerank], args.distctx)
 
-        # total up bytes read in merge
-        binfile = data_file_path(filerank)
-        idxfile = index_file_path(filerank)
-        numbytes[0] += os.stat(binfile)[stat.ST_SIZE]
-        numbytes[0] += os.stat(idxfile)[stat.ST_SIZE]
+        # total up bytes read during the merge
+        binfilerank = data_file_path(filerank)
+        idxfilerank = index_file_path(filerank)
+        numbytes[0] += os.stat(binfilerank)[stat.ST_SIZE]
+        numbytes[0] += os.stat(idxfilerank)[stat.ST_SIZE]
 
         # If user want to use both a parallel and serial merge (for testing),
         # rename the parallel output files so that the serial merge does not clobber them.
         if args.merge == 'both' and args.rank == 0:
-            binfile = data_file_path(filemain)
-            idxfile = index_file_path(filemain)
-            os.rename(binfile, binfile + ".par")
-            os.rename(idxfile, idxfile + ".par")
+            binfilemain = data_file_path(filemain)
+            idxfilemain = index_file_path(filemain)
+            os.rename(binfilemain, binfilemain + ".par")
+            os.rename(idxfilemain, idxfilemain + ".par")
 
     # Total up number of bytes read across all ranks,
     # and wait on all ranks before stopping the timer.
@@ -546,8 +526,6 @@ def rank_files_merge_serial(args):
     for rank in range(args.numranks):
         for key in args.columns:
             infile = get_filename(args, key, rank)
-
-#            msg(f"Merging file {infile}", flush=True)
             builders[key].merge_file_(infile)
 
     # sum up the number of merged bytes
@@ -610,11 +588,12 @@ def main():
     startup_start = time.time()
 
     # load the dataset
-    dset, err = load_dset(args)
-    if dset is None:
-        if err is not None:
-            raise err
-        return
+    if args.input.endswith(".jsonl"):
+        # assume file is JSONL format
+        dset = IndexedJSON(args.input, args.distctx)
+    else:
+        # otherwise load HuggingFace dataset
+        dset = load_dset(args)
     if args.rank == 0:
         print(dset)
         msg(f"Processing features: {args.columns}")
@@ -636,21 +615,21 @@
     if args.rank == 0:
         msg(f"Seconds to startup: {startup_end - startup_start}")
 
-    # have each rank write its file, returns False if any rank had a problem
-    success, err = rank_files_write(args, dset, idx, encoder)
-    if not success:
+    # have each rank write its file,
+    # all ranks should raise an exception if any rank has a problem
+    try:
+        rank_files_write(args, dset, idx, encoder)
+    except Exception as e:
+        # If any process fails, we skip the merge since the resulting file would be invalid.
+        # We still delete files to clean up, since those might be invalid anyway.
         if args.rank == 0:
-            # If any process fails, we skip the merge since the resulting file would be invalid.
-            # We still delete files to clean up, since those might be invalid anyway.
             msgerr(f"At least one process failed to write its file, skipping merge and cleaning up", flush=True)
 
         # delete per-rank files, do this even on error
         rank_files_delete(args)
 
-        # raise exception caught during write phase
-        if err is not None:
-            raise err
-        return
+        # re-raise exception caught during write phase
+        raise e
 
     # all ranks were successful writing their file, merge them into one
     rank_files_merge(args)
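With this change, main() treats any --input path ending in .jsonl as a local file and wraps it in IndexedJSON instead of going through load_dset. A small usage sketch of that path (the file name, the 'text' field, and the check_indexed_json.py script name are illustrative; it must run under torch.distributed.launch so that DistData can initialize its process group):

    # check_indexed_json.py -- run with:
    #   python -m torch.distributed.launch --nproc_per_node 2 check_indexed_json.py
    from megatron.data.distdata import DistData
    from indexed_json import IndexedJSON

    distctx = DistData(backend='gloo')
    dset = IndexedJSON('openwebtext-1000.jsonl', distctx)   # builds the .idx file on first use
    if distctx.rank == 0:
        print(dset)                   # IndexedJSON ( file: ... rows: ... )
        print(len(dset))              # number of lines in the .jsonl file
        print(dset[0]['text'][:80])   # assumes each record has a 'text' field

Because the .idx file is cached next to the .jsonl file and rebuilt only when the source is newer, later runs skip the collective scan entirely.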