Merged

78 commits
269af4e
add parallel merge using mpi
adammoody Aug 9, 2021
9ba081b
handle case where some ranks might have 0 items
adammoody Aug 10, 2021
d29a702
add inclusive scan prefix sum
adammoody Aug 11, 2021
ed49713
report more timing info
adammoody Aug 11, 2021
e94f2a0
Update megatron/data/indexed_dataset.py
adammoody Aug 12, 2021
687ff32
Update megatron/data/indexed_dataset.py
adammoody Aug 12, 2021
af59545
rename total size variable for clarity
adammoody Aug 12, 2021
4f648a0
move translation to bin/idx file names a level deeper
adammoody Aug 13, 2021
9f2ba6a
parallel merge for cached dataset
adammoody Aug 13, 2021
72d6c9c
add alltrue function
adammoody Aug 13, 2021
8b67bec
move collectives to new distdata class, add torch.distributed
adammoody Aug 14, 2021
3eca1f3
drop unused prefix_sum function
adammoody Aug 14, 2021
a691b48
allow ranks to pass a list of files to be merged
adammoody Aug 15, 2021
e4a34e2
check that input dataset files exist
adammoody Aug 15, 2021
8b168ca
fix: using wrong doc_idx list for mmap
adammoody Aug 16, 2021
7a02693
move init dist and collectives to distdata class
adammoody Aug 16, 2021
eca2940
add --merge option, move parallel/serial to their own functions
adammoody Aug 16, 2021
b14491d
Merge branch 'main' into pmerge
adammoody Aug 16, 2021
ec11281
Update megatron/data/distdata.py
adammoody Aug 16, 2021
354d13b
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
2dc3f7a
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
980e904
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
ebd20a6
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
69b2f49
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
50de06a
Update megatron/data/indexed_dataset.py
adammoody Aug 16, 2021
af290ad
drop extraneous numpy tolist calls
adammoody Aug 16, 2021
4b58c74
rename self.MPI to mpi4py
adammoody Aug 16, 2021
71a2fdc
handle case where no ranks have elements in their file
adammoody Aug 16, 2021
73d3a24
rename tokenize_start to time_start
adammoody Aug 16, 2021
b9e69be
drop unrelated comment in distdata.min
adammoody Aug 16, 2021
da615c6
add comment why pointers_shift is not None and add assert
adammoody Aug 16, 2021
c42f41f
note why pointers uses sizes count and offset values
adammoody Aug 16, 2021
a3a7d53
can just rely on rank 0 for the leading 0 element
adammoody Aug 17, 2021
163310a
add write_list function
adammoody Aug 17, 2021
01b2be0
determine element size
adammoody Aug 17, 2021
4b6e8ff
add checks for consistent element_size values
adammoody Aug 17, 2021
ea08555
check that at least one rank has a file to merge
adammoody Aug 17, 2021
2524fce
assert that torch backend is gloo or mpi
adammoody Aug 17, 2021
ca14d48
add collectives for assert and raise
adammoody Aug 17, 2021
d482f36
rename to allassert and allraise_if
adammoody Aug 17, 2021
28d76f5
check dtype instead of element_size
adammoody Aug 17, 2021
f706108
add uint32 to element_sizes table
adammoody Aug 17, 2021
f122883
infer dtype from files being merged
adammoody Aug 17, 2021
57c012e
add write_header function to indexed dataset classes
adammoody Aug 17, 2021
eed8327
call write_header internally from IndexedDataset classes
adammoody Aug 17, 2021
a75cfc2
return number of bytes written from write calls
adammoody Aug 17, 2021
afcfcf9
Merge branch 'main' into pmerge
adammoody Aug 17, 2021
74b733a
move scatterv to distdata class
adammoody Aug 17, 2021
dadb51b
add functions to format status and error messages
adammoody Aug 17, 2021
a2f8fa0
defer merge_files_dist to future PR
adammoody Aug 17, 2021
39e6cd7
open files using with, refresh comments
adammoody Aug 18, 2021
2a29d99
rely on default torch datatypes
adammoody Aug 18, 2021
d6fa895
fix some status messages from preprocess script
adammoody Aug 18, 2021
1216c0a
fix: exclusive scan computing pointers list
adammoody Aug 18, 2021
a64d3da
Merge branch 'pointerfix' into pmerge
adammoody Aug 18, 2021
fde439e
fix: exclusive scan to compute mmap pointers list
adammoody Aug 18, 2021
ba14351
note about seek
adammoody Aug 19, 2021
852fdd0
rename preprocess_dataset_mpi.py to preprocess_data_dist.py
adammoody Aug 19, 2021
61f4b46
update usage comments at top of script
adammoody Aug 19, 2021
22400f3
restore commented print_rank_0 statements
adammoody Aug 19, 2021
5cfcb95
restore status message in mmap merge_file_
adammoody Aug 19, 2021
74c4883
drop mpi4py, sad :(
adammoody Aug 19, 2021
373e514
Merge branch 'main' into pmerge
adammoody Aug 19, 2021
78ab715
add test case for parallel merge
adammoody Aug 19, 2021
002b403
add preprocess_data_dist test for serial merge
adammoody Aug 19, 2021
ba763f7
improve error handling
adammoody Aug 20, 2021
fa11159
refactor get_pointers code
adammoody Aug 20, 2021
7e53fd3
bug fix in exscan
adammoody Aug 20, 2021
53df36f
further refactor get_pointers
adammoody Aug 20, 2021
c43348f
move exscan collective for pointers outside of try block
adammoody Aug 20, 2021
81c21dd
clarify some comments
adammoody Aug 20, 2021
adee502
include string 1k in name of test files
adammoody Aug 20, 2021
13ae421
use temporary file for index
adammoody Aug 20, 2021
f3e1b1d
fix: implement scatterv from torch.distributed.scatter
adammoody Aug 23, 2021
42962e1
switch to pad method in torch.nn.functional
adammoody Aug 25, 2021
9a2f383
return data received in scatterv as new tensor
adammoody Aug 25, 2021
15b7603
raise exception if conflicting scratch and merge options
adammoody Aug 25, 2021
4adaddd
use allraise method from distdata in preprocess_data_dist
adammoody Aug 25, 2021
220 changes: 220 additions & 0 deletions megatron/data/distdata.py
@@ -0,0 +1,220 @@
import os
import numpy as np

import torch
import torch.nn.functional as F
import torch.distributed as dist

class DistDataError(Exception):
    """Defines an empty exception to throw when some other rank hit a real exception."""
    pass

class DistData(object):
Member:
So that's an awesome abstraction for multinode, especially in the case where we might want to support both torch.distributed and mpi4py. Though as I'm thinking about it, supporting mpi4py will become costly very soon and might not bring much to the table, since we would need to implement every improvement in two frameworks that can essentially use the same MPI backend. As you mentioned in a comment, we should be able to do everything using torch.distributed, so let's remove mpi4py and come back to it the day torch.distributed isn't enough.

Sorry for my mistake; I thought it would be interesting to support both cases, but I feel it ends up being a burden here, where we need to create a higher-level abstraction for not much.

Contributor Author (adammoody):
Yeah, it would require some maintenance. I can yank out the MPI just before we merge the PR.

    def __init__(self, backend='gloo'):
        assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'"

        dist.init_process_group(backend, init_method="env://")

        # lookup our process rank and the group size
        self.rank = dist.get_rank()
        self.numranks = dist.get_world_size()

    def allassert(self, cond, msg):
        """Check that cond is True on all ranks, assert with msg everywhere if not.

        To prevent deadlocks in cases where an assertion might only fail on one rank,
        this executes an allreduce to ensure that if any rank finds that an assertion
        has been violated, all ranks fail an assertion check.
        The condition must be true on all ranks for this not to assert.
        """
        alltrue = self.alltrue(cond)
        assert alltrue, msg
Member:
You rely on the fact that all ranks will either pass or raise due to alltrue, right? I'd say that's non-trivial and should be commented; at least that would explain why this method has a different pattern from the other error-handling methods.

Contributor Author (adammoody):
Added some more explanation here and for allraise_if(err), which is similar.
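
For readers following the thread, a minimal usage sketch of the two collective error-handling calls (assuming an instance distdata = DistData(); the file name and work function are hypothetical):

# Collective assert: if the condition is False on any rank,
# every rank fails the assertion together instead of deadlocking.
distdata.allassert(os.path.exists(local_file),
                   f"rank {distdata.rank}: input file {local_file} is missing")

# Collective raise: capture the local exception, then raise everywhere.
err = None
try:
    process_local_chunk()  # hypothetical per-rank work that may fail
except Exception as e:
    err = e
distdata.allraise_if(err)  # all ranks raise if any rank failed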


    def allraise_if(self, err):
        """Raise exception if err is not None on any rank.

        Similarly to allassert, this raises an exception on all ranks if err
        is set to an exception on any rank. Rank(s) where err is not None
        re-raise err as exception, and ranks where err is None raise DistDataError.
        Thus all ranks raise an exception if any rank has an active exception,
        which helps avoid deadlocks in cases where an exception may be raised
        on a subset of ranks.
        """
        alltrue = self.alltrue(err is None)
        if not alltrue:
            # At least one rank raised an exception.
            # Re-raise the actual exception if this rank threw one.
            if err is not None:
                raise err

            # TODO: is there a better exception to use here?
            # On other ranks, raise an "empty" exception to indicate
            # that we're only failing because someone else did.
            raise DistDataError

    def barrier(self):
        """Globally synchronize all processes"""
        dist.barrier()

    def bcast(self, val, root):
        """Broadcast a scalar value from root to all ranks"""
        vals = [val]
        dist.broadcast_object_list(vals, src=root)
        return vals[0]

    def scatterv_(self, invals: np.array, counts: list, root: int = 0):
        """Scatter int64 values from invals according to counts array, return received portion in a new tensor"""

        self.allassert(len(counts) == self.numranks,
Member:
That's really nice!!! I like the self.allassert system!

Contributor Author (adammoody):
Thanks. It is handy. It does add an allreduce, so it's not as cheap as a normal assertion. Also, it could use some work: it would be nice to single out the ranks where the assert actually failed vs. those where the assert passed but are asserting anyway to keep in sync.

Member:
So I think there's a -O flag that allows you to bypass all asserts. Maybe there's a decorator so that we can always ignore self.allassert when running on really long files?

Member (@thomasw21, Aug 20, 2021):
You could have a decorator for this, basically preventing the check from running in optimized mode:

from functools import wraps

def run_in_debug_only(default_value):
    def _run_in_debug_only(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if __debug__:
                return func(*args, **kwargs)
            else:
                return default_value
        return wrapper
    return _run_in_debug_only

And then you'd be able to use things like

@run_in_debug_only(default_value=True)
def allassert():
    ...

This might be overkill, but then you'd have something very similar to assert, i.e. something that's removed when you run python -O script.

Contributor Author (adammoody):
Oh, that's a neat trick. I can take a look.

Member:
Let's put that in a future PR. For me this PR introduces the feature; we can always improve on it later.
f"Length of counts list {len(counts)} does not match number of ranks {self.numranks}")

# Define list of tensors to scatter on the root.
# torch.distributed.scatter requires each tensor to be the same shape,
# so find the max size across all count values and pad.
max_size = max(counts)
scatterlist = None
if self.rank == root:
slices = list(torch.split(torch.from_numpy(invals), counts))
scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices]

# Receive a tensor of the max count size from the root,
# then copy values into output numpy array, which may be smaller.
recvtensor = torch.zeros(max_size, dtype=torch.int64)
dist.scatter(recvtensor, scatterlist, src=root)
Member (@thomasw21, Aug 23, 2021):
What about scatter_object_list? I tried it and it seems to work nicely, though I don't know what the cost of that method is vs. padding.

Contributor Author (adammoody):
Oh, interesting. I seem to be using an old enough torch.distributed that I don't have scatter_object_list:

AttributeError: module 'torch.distributed' has no attribute 'scatter_object_list'

It looks like it was added about 9 months ago in this commit:

pytorch/pytorch@02d89f9

I can also see that internally it just calls broadcast and scatter a couple of times. After seeing that, I think our pad method is probably the best way to go after all.

Member:
Yeah, let's go with padding.
        return recvtensor[:counts[self.rank]]
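
A quick usage sketch of scatterv_ (counts and values are illustrative, assuming a 3-rank job and distdata = DistData()):

counts = [3, 1, 2]                      # one entry per rank
invals = np.arange(6, dtype=np.int64)   # only significant on the root
recv = distdata.scatterv_(invals, counts, root=0)
# rank 0 receives tensor([0, 1, 2]), rank 1 tensor([3]), rank 2 tensor([4, 5])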

    def alltrue(self, val):
        """Returns True if all procs input True, False otherwise"""
        # torch.dist does not support reductions with bool types
        # so we cast to int and cast the result back to bool
        tensor = torch.tensor([int(val)], dtype=torch.int32)
        dist.all_reduce(tensor, op=dist.ReduceOp.BAND)
        return bool(tensor[0])

    def sum(self, val):
        """Compute sum of a scalar val, and return total on all ranks."""
        tensor = torch.tensor([val])
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor[0]

    def exscan(self, val: int):
        """Compute prefix sum (exclusive scan) of int64 val, and return offset of each rank."""
        # torch.distributed doesn't have a scan, so fallback to allreduce
        tensor = torch.zeros(self.numranks, dtype=torch.int64)
        tensor[self.rank:] = val
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return int(tensor[self.rank]) - val
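
A worked example of this allreduce-based scan (a hypothetical 4-rank job): if the ranks hold val = 3, 1, 4, 2, the tensors sum element-wise to [3, 4, 8, 10], so tensor[rank] - val yields the exclusive offsets 0, 3, 4, 8.

offset = distdata.exscan(num_local_items)  # this rank's starting offset (illustrative)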

    def min(self, val):
        """Return minimum of scalar val to all ranks."""
        tensor = torch.tensor([val])
        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
        return tensor[0]

    def minrank(self, cond):
        """Find first rank whose condition is True, return that rank if any, None otherwise."""
        minrank = self.numranks
        if cond:
            minrank = self.rank
        minrank = self.min(minrank)

        if minrank < self.numranks:
            return minrank
        return None

    def bcast_first(self, val):
        """Broadcast val from first rank where it is not None, return val if any, None otherwise"""
        # Find the first rank with a valid value.
        minrank = self.minrank(val is not None)

        # If there is no rank with a valid value, return None
        if minrank is None:
            return None

        # Otherwise broadcast the value from the first valid rank.
        val = self.bcast(val, root=minrank)
        return val

    def all_sum_(self, vals: np.array):
        """Sums values in numpy array vals element-wise and update vals in place with final result on all ranks"""
        # Builds torch.tensor with from_numpy to use same underlying memory as numpy array.
        tensor = torch.from_numpy(vals)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
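
Because torch.from_numpy shares memory with the numpy array, the allreduce writes its result directly into vals. A small illustrative call (values are made up):

vals = np.array([1, 2, 3], dtype=np.int64)  # this rank's contribution
distdata.all_sum_(vals)  # vals now holds the element-wise totals on every rank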

    def open(self, filename, truncate=None):
        """Create, truncate, and open a file shared by all ranks."""

        # Don't truncate existing file until all ranks reach this point
        self.barrier()

        # We'll capture any exception in this variable
        err = None

        # Rank 0 creates and truncates file.
        if self.rank == 0:
            try:
                f = open(filename, 'wb')

                # Some file systems like GPFS deliver faster write speed
                # if the file size is known before data is written to the file.
                if truncate is not None:
                    f.truncate(truncate)

            except Exception as e:
Member:
We might want to catch BaseException also, no?

Contributor Author (adammoody):
Yes, added that in as well.

Contributor Author (adammoody):
Well, I take that back. I guess I don't know. I've seen some articles that say most of the time you don't want to catch BaseException. Is that something we should do?

Member:
Well, it depends on what error we're trying to catch. What would open(..., "wb") really throw?

Contributor Author (adammoody):
I'll take a closer look.
                err = e

        # Verify that rank 0 created the file
        self.allraise_if(err)

        # Wait for rank 0 to open (and truncate) file,
        # then have all ranks open file for writing.
        if self.rank != 0:
            try:
                f = open(filename, 'r+b')
            except Exception as e:
                err = e

        # Verify that all ranks successfully opened the file
        self.allraise_if(err)

        return f
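
Putting the pieces together, a minimal sketch of the intended write pattern (names and offsets are hypothetical; the exscan-based offset is an assumption, not quoted from the PR):

f = distdata.open('merged.bin', truncate=total_bytes)  # size known up front
f.seek(my_offset)   # e.g. a byte offset computed with distdata.exscan
f.write(my_bytes)   # each rank writes its own disjoint region
f.close()
distdata.barrier()  # make sure all writes land before the file is used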

    def remove(self, filename):
        """Remove a shared file."""

        # Don't remove the file until all are ready
        self.barrier()

        # We'll capture any exception in this variable
        err = None

        # Rank 0 removes the file if it exists.
        if self.rank == 0:
            try:
                if os.path.exists(filename):
                    os.remove(filename)
            except Exception as e:
                err = e

        # Verify that rank 0 successfully removed the file.
        self.allraise_if(err)

    def rename(self, srcfile, destfile):
        """Rename a shared file."""

        # Don't rename until all are ready
        self.barrier()

        # We'll capture any exception in this variable
        err = None

        # Rank 0 renames the file.
        if self.rank == 0:
            try:
                if os.path.exists(srcfile):
                    os.rename(srcfile, destfile)
            except Exception as e:
                err = e

        # Verify that the rename succeeded
        self.allraise_if(err)