distributed merge of per-rank Megatron data files #55
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,220 @@ | ||
| import os | ||
| import numpy as np | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| import torch.distributed as dist | ||
|
|
||
| class DistDataError(Exception): | ||
| """Defines an empty exception to throw when some other rank hit a real exception.""" | ||
| pass | ||
|
|
||
| class DistData(object): | ||
| def __init__(self, backend='gloo'): | ||
| assert backend in ['gloo', 'mpi'], f"torch.distributed backend '{backend}' is not supported, valid options are 'gloo' or 'mpi'" | ||
|
|
||
| dist.init_process_group(backend, init_method="env://") | ||
|
|
||
| # lookup our process rank and the group size | ||
| self.rank = dist.get_rank() | ||
| self.numranks = dist.get_world_size() | ||
|
|
||
| def allassert(self, cond, msg): | ||
| """Check that cond is True on all ranks, assert with msg everywhere if not. | ||
|
|
||
| To prevent deadlocks in cases where an assertion might only fail on one rank, | ||
| this executes an allreduce to ensure that if any rank finds that an assertion | ||
| has been violated, all ranks fail an assertion check. | ||
| The condition must be true on all ranks for this not to assert. | ||
| """ | ||
| alltrue = self.alltrue(cond) | ||
| assert alltrue, msg | ||
|
Member
You rely on the fact that all ranks will either pass or raise due to alltrue, right? I'd say that's non-trivial and should be commented. At least it would explain why it has a different pattern from the other error-handling methods.
Contributor
Author
Added some more explanation here and for
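To illustrate the point raised above, here is a minimal usage sketch (not part of the PR) of why the collective check avoids deadlock; it assumes `distdata` is an initialized `DistData` instance and the per-rank filename is made up:

```python
import os

# Suppose only some ranks fail a local check, e.g. a per-rank file is missing.
# A plain `assert` would stop those ranks while the others continue into their
# next collective call and hang. allassert runs an allreduce first, so the
# failure is seen everywhere and every rank raises together.
have_file = os.path.exists(f"rank_{distdata.rank}.bin")  # may differ per rank
distdata.allassert(have_file, f"rank {distdata.rank}: missing per-rank file")
```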
||
|
|
||
| def allraise_if(self, err): | ||
| """Raise exception if err is not None on any rank. | ||
|
|
||
| Similarly to allassert, this raises an exception on all ranks if err | ||
| is set to an exception on any rank. Rank(s) where err is not None | ||
| re-raise err as exception, and ranks where err is None raise DistDataError. | ||
| Thus all ranks raise an exception if any rank has an active exception, | ||
| which helps avoid deadlocks in cases where an exception may be raised | ||
| on a subset of ranks. | ||
| """ | ||
| alltrue = self.alltrue(err is None) | ||
| if not alltrue: | ||
| # At least one rank raised an exception. | ||
| # Re-raise the actual exception if this rank threw one. | ||
| if err is not None: | ||
| raise err | ||
|
|
||
| # TODO: is there a better exception to use here? | ||
| # On other ranks, raise an "empty" exception to indicate | ||
| # that we're only failing because someone else did. | ||
| raise DistDataError | ||
|
|
||
| def barrier(self): | ||
| """Globally synchronize all processes""" | ||
| dist.barrier() | ||
|
|
||
| def bcast(self, val, root): | ||
| """Broadcast a scalar value from root to all ranks""" | ||
| vals = [val] | ||
| dist.broadcast_object_list(vals, src=root) | ||
| return vals[0] | ||
|
|
||
| def scatterv_(self, invals: np.array, counts: list, root:int=0): | ||
| """Scatter int64 values from invals according to counts array, return received portion in a new tensor""" | ||
|
|
||
| self.allassert(len(counts) == self.numranks, | ||
|
Member
That's really nice!!! I like the
Contributor
Author
Thanks. It is handy. It does add an allreduce, so it's not as cheap as a normal assertion. Also, it could use some work: it would be nice to single out the ranks where the assert actually failed versus those that only assert to stay in sync.
Member
So I think there's a
Member
You could have a decorator for this, basically preventing it from running in optimized mode. Then you'd be able to use something like:
@run_in_debug_only(default_value=True)
def allassert():
    ...
This might be overkill though, but then you'd have something that's very similar to
Contributor
Author
Oh, that's a neat trick. I can take a look.
Member
Let's put that in a future PR. For me this PR introduces the feature; we can always improve on it later.
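For reference, a minimal sketch of the decorator idea floated above, assuming it keys off Python's built-in `__debug__` flag (i.e. the check is skipped under `python -O`); none of these names exist in the PR, so treat them as hypothetical:

```python
import functools

def run_in_debug_only(default_value=True):
    """Hypothetical decorator: run the wrapped check only when __debug__ is
    True (normal mode); under `python -O` just return default_value."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if __debug__:
                return func(*args, **kwargs)
            return default_value  # skip the (collective) check under -O
        return wrapper
    return decorator

# It could then wrap the collective assertion so the extra allreduce
# disappears in optimized runs:
#
# @run_in_debug_only(default_value=True)
# def allassert(self, cond, msg):
#     ...
```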
||
| f"Length of counts list {len(counts)} does not match number of ranks {self.numranks}") | ||
|
|
||
| # Define list of tensors to scatter on the root. | ||
| # torch.distributed.scatter requires each tensor to be the same shape, | ||
| # so find the max size across all count values and pad. | ||
| max_size = max(counts) | ||
| scatterlist = None | ||
| if self.rank == root: | ||
| slices = list(torch.split(torch.from_numpy(invals), counts)) | ||
| scatterlist = [F.pad(s, (0, max_size - len(s))) for s in slices] | ||
|
|
||
| # Receive a tensor of the max count size from the root, | ||
| # then copy values into output numpy array, which may be smaller. | ||
| recvtensor = torch.zeros(max_size, dtype=torch.int64) | ||
| dist.scatter(recvtensor, scatterlist, src=root) | ||
|
Member
What about
Contributor
Author
Oh, interesting. So I seem to be using an old enough version that it's not available. It looks like it was added about 9 months ago in this commit: I can also see that internally it just calls broadcast and scatter a couple of times. After seeing that, I think our pad method is probably the best way to go after all.
Member
Yeah, let's go with padding.
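As a usage sketch of the padded scatter being discussed (assuming an initialized `DistData` instance named `distdata` and a 3-rank job; variable names here are illustrative, not from the PR):

```python
import numpy as np

# Root builds one flat int64 array holding every rank's piece back to back;
# counts[i] says how many elements go to rank i (one entry per rank).
counts = [3, 1, 2]
invals = None
if distdata.rank == 0:
    invals = np.arange(sum(counts), dtype=np.int64)   # [0 1 2 | 3 | 4 5]

# Each rank gets back a tensor of length counts[rank]; internally the slices
# are padded to the max count so torch.distributed.scatter sees equal shapes.
mine = distdata.scatterv_(invals, counts, root=0)
print(distdata.rank, mine)
```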
||
| return recvtensor[:counts[self.rank]] | ||
|
|
||
| def alltrue(self, val): | ||
| """Returns True if all procs input True, False otherwise""" | ||
| # torch.dist does not support reductions with bool types | ||
| # so we cast to int and cast the result back to bool | ||
| tensor = torch.tensor([int(val)], dtype=torch.int32) | ||
| dist.all_reduce(tensor, op=dist.ReduceOp.BAND) | ||
| return bool(tensor[0]) | ||
|
|
||
| def sum(self, val): | ||
| """Compute sum of a scalar val, and return total on all ranks.""" | ||
| tensor = torch.tensor([val]) | ||
| dist.all_reduce(tensor, op=dist.ReduceOp.SUM) | ||
| return tensor[0] | ||
|
|
||
| def exscan(self, val: int): | ||
| """Compute prefix sum (exclusive scan) of int64 val, and return offset of each rank.""" | ||
| # torch.distributed doesn't have a scan, so fallback to allreduce | ||
| tensor = torch.zeros(self.numranks, dtype=torch.int64) | ||
| tensor[self.rank:] = val | ||
| dist.all_reduce(tensor, op=dist.ReduceOp.SUM) | ||
| return int(tensor[self.rank]) - val | ||
|
|
||
| def min(self, val): | ||
| """Return minimum of scalar val to all ranks.""" | ||
| tensor = torch.tensor([val]) | ||
| dist.all_reduce(tensor, op=dist.ReduceOp.MIN) | ||
| return tensor[0] | ||
|
|
||
| def minrank(self, cond): | ||
| """Find first rank whose condition is True, return that rank if any, None otherwise.""" | ||
| minrank = self.numranks | ||
| if cond: | ||
| minrank = self.rank | ||
| minrank = self.min(minrank) | ||
|
|
||
| if minrank < self.numranks: | ||
| return minrank | ||
| return None | ||
|
|
||
| def bcast_first(self, val): | ||
| """Broadcast val from first rank where it is not None, return val if any, None otherwise""" | ||
| # Find the first rank with a valid value. | ||
| minrank = self.minrank(val is not None) | ||
|
|
||
| # If there is no rank with a valid value, return None | ||
| if minrank is None: | ||
| return None | ||
|
|
||
| # Otherwise broadcast the value from the first valid rank. | ||
| val = self.bcast(val, root=minrank) | ||
| return val | ||
|
|
||
| def all_sum_(self, vals: np.array): | ||
| """Sums values in numpy array vals element-wise and update vals in place with final result on all ranks""" | ||
| # Builds torch.tensor with from_numpy to use same underlying memory as numpy array. | ||
| tensor = torch.from_numpy(vals) | ||
| dist.all_reduce(tensor, op=dist.ReduceOp.SUM) | ||
|
|
||
| def open(self, filename, truncate=None): | ||
| """Create, truncate, and open a file shared by all ranks.""" | ||
|
|
||
| # Don't truncate existing file until all ranks reach this point | ||
| self.barrier() | ||
|
|
||
| # We'll capture any exception in this variable | ||
| err = None | ||
|
|
||
| # Rank 0 creates and truncates file. | ||
| if self.rank == 0: | ||
| try: | ||
| f = open(filename, 'wb') | ||
|
|
||
| # Some file systems like GPFS deliver faster write speed | ||
| # if the file size is known before data is written to the file. | ||
| if truncate is not None: | ||
| f.truncate(truncate) | ||
|
|
||
| except Exception as e: | ||
|
Member
We might want to catch
Contributor
Author
Yes, added that in as well.
Contributor
Author
Well, I take that back. I guess I don't know. I've seen some articles that say most of the time you don't want to catch BaseException. Is that something we should do?
Member
Well, it depends on what error we're trying to catch. What would open(..., "wb") really throw?
Contributor
Author
I'll take a closer look.
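For what it's worth, in CPython both `open()` and `truncate()` report file-system problems as `OSError` subclasses (`PermissionError`, `FileNotFoundError`, `IsADirectoryError`, ...), so a narrower variant of the rank-0 block inside `DistData.open` might look like the sketch below; keeping the broad `Exception` as in the PR is also defensible, since any failure is re-raised collectively via `allraise_if` anyway.

```python
# Sketch only: a narrower catch for the rank-0 create/truncate step.
err = None
if self.rank == 0:
    try:
        f = open(filename, 'wb')
        if truncate is not None:
            f.truncate(truncate)
    except OSError as e:   # covers PermissionError, FileNotFoundError, ...
        err = e
self.allraise_if(err)
```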
||
| err = e | ||
|
|
||
| # Verify that rank 0 created the file | ||
| self.allraise_if(err) | ||
|
|
||
| # Wait for rank 0 to open (and truncate) file, | ||
| # then have all ranks open file for writing. | ||
| if self.rank != 0: | ||
| try: | ||
| f = open(filename, 'r+b') | ||
| except Exception as e: | ||
| err = e | ||
|
|
||
| # Verify that all ranks successfully opened the file | ||
| self.allraise_if(err) | ||
|
|
||
| return f | ||
|
|
||
| def remove(self, filename): | ||
| """Remove a shared file.""" | ||
|
|
||
| # Don't remove the file until all are ready | ||
| self.barrier() | ||
|
|
||
| # We'll capture any exception in this variable | ||
| err = None | ||
|
|
||
| # Rank 0 removes the file if it exists. | ||
| if self.rank == 0: | ||
| try: | ||
| if os.path.exists(filename): | ||
| os.remove(filename) | ||
| except Exception as e: | ||
| err = e | ||
|
|
||
| # Verify that rank 0 successfully removed the file. | ||
| self.allraise_if(err) | ||
|
|
||
| def rename(self, srcfile, destfile): | ||
| """Rename a shared file.""" | ||
|
|
||
| # Don't rename until all are ready | ||
| self.barrier() | ||
|
|
||
| # We'll capture any exception in this variable | ||
| err = None | ||
|
|
||
| # Rank 0 renames the file. | ||
| if self.rank == 0: | ||
| try: | ||
| if os.path.exists(srcfile): | ||
| os.rename(srcfile, destfile) | ||
| except Exception as e: | ||
| err = e | ||
|
|
||
| # Verify that the rename succeeded | ||
| self.allraise_if(err) | ||
So that's an awesome abstraction for multinode, especially in the case where we might want to support both torch.distributed and mpi4py. Though as I'm thinking about it, supporting mpi4py will become costly very soon and might not bring much to the table, as we would need to implement every improvement in two frameworks that can essentially use the same mpi backend. As you mentioned in a comment, we should be able to do everything using torch.distributed, so let's remove mpi4py and come back to it the day torch.distributed isn't enough. Sorry for my mistake: I thought it would be interesting to support both cases, but I feel it ends up being a burden here, where we need to create a higher-level abstraction for not much.
Yeah, it would require some maintenance. I can yank out the MPI code just before we merge the PR.
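To make the abstraction concrete, here is a small end-to-end usage sketch of the wrapper from this diff, assuming the new file is importable as `distdata`, the process group environment variables are set (e.g. via torchrun), and the output filename is made up:

```python
import numpy as np
from distdata import DistData  # hypothetical module name for the file in this diff

ddata = DistData(backend='gloo')

# Each rank contributes a count; exscan gives its write offset, sum the total.
count = ddata.rank + 1
offset = ddata.exscan(count)
total = ddata.sum(count)

# All ranks open the same shared file; rank 0 creates and truncates it first.
f = ddata.open('merged.bin', truncate=int(total) * 8)
f.seek(offset * 8)
f.write(np.arange(count, dtype=np.int64).tobytes())
f.close()
ddata.barrier()
```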