
Added Horovod distributed backend #1529

Merged · 17 commits · Apr 22, 2020

Conversation

tgaddair
Contributor

Fixes #1518.

Make the following change to your Trainer to run on GPU (single or multiple) with Horovod:

trainer = Trainer(distributed_backend='horovod', gpus=1)

Or to run on CPU:

trainer = Trainer(distributed_backend='horovod')

Then the training script can be launched via the horovodrun command-line tool, where the host/GPU allocation is specified:

horovodrun -np 8 -H host1:4,host2:4 python train.py
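
For context, a complete train.py driven this way might look roughly like the sketch below; ToyModel and its data are placeholders for illustration, and only the Trainer arguments come from this PR:

# train.py -- a minimal sketch; ToyModel and its data are toy placeholders
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
        return DataLoader(data, batch_size=32)


if __name__ == '__main__':
    # Each Horovod worker drives a single GPU, so gpus=1 even for multi-GPU jobs;
    # horovodrun controls how many workers (and hence GPUs) are launched in total.
    trainer = pl.Trainer(distributed_backend='horovod', gpus=1, max_epochs=1)
    trainer.fit(ToyModel())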

@pep8speaks

pep8speaks commented Apr 19, 2020

Hello @tgaddair! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2020-04-22 21:21:47 UTC

@mergify mergify bot requested a review from a team April 19, 2020 17:17
@@ -219,6 +220,13 @@ def set_distributed_mode(self, distributed_backend):
        self.use_ddp = True
        self.data_parallel_device_ids = None
        self.on_gpu = False
    elif distributed_backend == 'horovod':
Contributor

It would be nice to be transparent to the user.
Can we automate setting this, so the abstraction doesn't bleed?

Contributor

(the mpirun thing)

Contributor Author

Just to make sure I understand you correctly: is the idea that when running via horovodrun or mpirun, if the user has not specified distributed_backend, then we will automatically set distributed_backend='horovod' here?

We could certainly do that when running with horovodrun + our Gloo backend, as we have special environment variables we can check (HOROVOD_RANK for example). Doing so with mpirun is more tricky, because different MPI implementations have different environment variables. Also, in the future, there might be another distributed backend other than Horovod that uses MPI.

So maybe we could automate it for horovodrun but still require them to set it explicitly for mpirun? (Let me know if I misunderstood your suggestion).

Contributor

Just to make sure I understand you correctly: is the idea that when running via horovodrun or mpirun, if the user has not specified distributed_backend, then we will automatically set distributed_backend='horovod' here?

Yes!

So maybe we could automate it for horovodrun but still require them to set it explicitly for mpirun? (Let me know if I misunderstood your suggestion).

Let's do this for now (v1), and for v2 maybe we set it explicitly for mpirun? I just don't know enough about mpirun yet, but if mpirun can run any backend then the user should be forced to set it.

Contributor Author

Sounds good! I added a has_horovodrun() check in distrib_data_parallel.py that checks for Gloo or OpenMPI environment variables set by horovodrun. Also added a test. Let me know if that aligns with what you were thinking.
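
For reference, such a check can be as simple as looking for the environment variables that horovodrun exports; a sketch (the exact variable names and the location in the merged code may differ):

import os

def has_horovodrun():
    """Detect whether this process was launched by `horovodrun`,
    which exports Gloo- or OpenMPI-specific environment variables."""
    return 'OMPI_COMM_WORLD_RANK' in os.environ or 'HOROVOD_RANK' in os.environ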

@mergify mergify bot requested a review from a team April 19, 2020 19:03
@mergify
Contributor

mergify bot commented Apr 19, 2020

This pull request is now in conflict... :(

@williamFalcon
Contributor

@tgaddair I love this! Wondering if we can automate the comment I added, so the user can use Horovod without remembering anything other than turning on the flag.

@Borda Borda added the feature Is an improvement or enhancement label Apr 20, 2020
@Borda Borda self-assigned this Apr 20, 2020
@Borda Borda added this to the 0.7.4 milestone Apr 20, 2020
Borda
Borda previously requested changes Apr 20, 2020
CHANGELOG.md Outdated
set_proc_rank(self.proc_rank)

if hvd.rank() != 0:
    self.logger = None

Contributor Author

The errors come from a race condition where different ranks will attempt to mkdir the same directory, leading to an exception being raised on one of the workers. For example, this can happen when creating a SummaryWriter, which is why in Horovod we only do so on rank 0.
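
As a concrete illustration of the rank-0 guard described above (a standalone sketch, not the code in this PR):

import horovod.torch as hvd
from torch.utils.tensorboard import SummaryWriter

hvd.init()

# Only rank 0 creates the log directory and writer; otherwise several ranks
# may race to mkdir the same path and one of them raises.
writer = SummaryWriter(log_dir='lightning_logs') if hvd.rank() == 0 else None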

Contributor

In Lightning we already handle setting loggers etc. only on rank 0, btw.

Contributor Author

I see. I updated to set the logger ranks to hvd.rank() instead of deleting them outside of rank 0. Let me know if that makes more sense.

parser.add_argument('--trainer-options', required=True)


def test(trainer_options):
Member

test what?

Contributor Author

Renamed for clarity and added a docstring at the top of the file to explain usage.

@@ -0,0 +1,36 @@
import argparse
Member

Is this meant to be a (unit) test? Because by its name it won't be found.
Also, why is it under data/horovod/? Would it rather be tests/models/script_train_horovod.py?

Contributor Author

Added a docstring at the top for clarity. This script is meant to be executed from test_horovod.py. The reason for this is to test driving the training via horovodrun using multiple parallel worker processes.
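
For illustration, the driving side could amount to something like the sketch below (the script name and options are placeholders; the actual test uses its own helpers):

import json
import subprocess
import sys

# Launch the standalone training script across two worker processes via
# horovodrun, mirroring how a user would run it from the command line.
trainer_options = json.dumps({'max_epochs': 1, 'distributed_backend': 'horovod'})
subprocess.run(
    ['horovodrun', '-np', '2', sys.executable, 'train_script.py',
     '--trainer-options', trainer_options],
    check=True,
)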

tests/models/test_horovod.py Outdated
@mergify mergify bot requested review from a team April 20, 2020 06:47

# Horovod: wrap optimizers to perform gradient aggregation via allreduce
self.optimizers = [
    hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
Contributor

What happens when I do:

def configure_optimizers(self):
    return Adam(self.generator.parameters()), Adam(self.discriminator.parameters())

Contributor

Wouldn't this line here break?

model.named_parameters()

Should we not instead do:

[hvd.DistributedOptimizer(opt, named_parameters=opt.named_parameters()) for opt in self.optimizers]

This might be a silly question as I don't know the details of DistributedOptimizer

Contributor Author

Yes, good catch! This was an oversight on my part. I added a fix, and added a unit test specifically for GAN / multi-optimizers.
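
For anyone following along, the per-optimizer wrapping can be sketched roughly as below, assuming the surrounding trainer context (self.optimizers, model, and hvd already available); the helper name is hypothetical and the merged code may differ in detail:

def _named_params_for(model, optimizer):
    # Keep only the model parameters that belong to this optimizer's param groups,
    # so each DistributedOptimizer allreduces a disjoint set of gradients.
    opt_params = {p for group in optimizer.param_groups for p in group['params']}
    return [(name, p) for name, p in model.named_parameters() if p in opt_params]


self.optimizers = [
    hvd.DistributedOptimizer(optimizer,
                             named_parameters=_named_params_for(model, optimizer))
    for optimizer in self.optimizers
]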

@mergify mergify bot requested a review from a team April 20, 2020 20:34
@mergify mergify bot requested a review from a team April 21, 2020 01:29
@codecov

codecov bot commented Apr 21, 2020

Codecov Report

Merging #1529 into master will increase coverage by 0%.
The diff coverage is 84%.

@@          Coverage Diff           @@
##           master   #1529   +/-   ##
======================================
  Coverage      89%     89%           
======================================
  Files          68      68           
  Lines        3811    3906   +95     
======================================
+ Hits         3385    3471   +86     
- Misses        426     435    +9     

@tgaddair
Contributor Author

Hey @williamFalcon, looks like there is an incompatibility between PyTorch Lightning and PyTorch 1.5.0 (released last night) that's causing the CI failures:

FAILED test_cpu.py::test_multi_cpu_model_ddp - _pickle.PicklingError: Can't pickle <class 'torch._C._VariableFunctions'>: it's not the same object as torch._C._VariableFunctions

Is someone on your end looking into this? Happy to file an issue.

@Borda
Member

Borda commented Apr 21, 2020

Is someone on your end looking into this? Happy to file an issue.

Probably an issue is not needed; we are already working on it in #1552.

@tgaddair
Contributor Author

tgaddair commented Apr 21, 2020

Is someone on your end looking into this? Happy to file an issue.

Probably an issue is not needed; we are already working on it in #1552.

Thanks @Borda! Once that lands I'll merge in from master and tests should be good again.

Let me know if there's any additional feedback you have on this PR.

@mergify
Contributor

mergify bot commented Apr 22, 2020

This pull request is now in conflict... :(

@williamFalcon
Contributor

@tgaddair fixed on master. Want to rebase so we can merge this?
I want to get this in before merging the other PRs since this touches so many parts, haha.

@williamFalcon williamFalcon added the priority: 0 High priority task label Apr 22, 2020
@tgaddair tgaddair force-pushed the horovod branch 4 times, most recently from 96a190e to a76736a Compare April 22, 2020 15:58
@Borda
Member

Borda commented Apr 22, 2020

we do not use tox anymore...

@tgaddair
Contributor Author

we do not use tox anymore...

I see, I mistook a failure due to a corrupt pip cache for a tox issue. Is there a way to refresh the pip cache? I just commented out that step for now to get tests to pass, not sure what will happen when I restore that line.

@Borda
Member

Borda commented Apr 22, 2020

I see, I mistook a failure due to a corrupt pip cache for a tox issue. Is there a way to refresh the pip cache? I just commented out that step for now to get tests to pass, not sure what will happen when I restore that line.

I tried dropping the cache some time ago and didn't find a way to do it...
that is the reason why we should rather use images than a cache :]

@tgaddair
Contributor Author

I see, I mistook a failure due to a corrupt pip cache for a tox issue. Is there a way to refresh the pip cache? I just commented out that step for now to get tests to pass, not sure what will happen when I restore that line.

I tried dropping the cache some time ago and didn't find a way to do it...
that is the reason why we should rather use images than a cache :]

Docker images would be much better, I agree. Looks like I was able to refresh the cache by running rm -rf ~/.cache/pip, then removing that line and running again. Not very elegant, but looks like it worked (fingers crossed).

@williamFalcon williamFalcon merged commit 7024177 into Lightning-AI:master Apr 22, 2020
@williamFalcon
Contributor

🎆

@tgaddair
Contributor Author

Hey @Borda @williamFalcon looks like the Drone GPU test timeout was recently changed from 30 minutes to 15 minutes. Before this PR, those tests took about 4:30 minutes to run, and were taking about 18 minutes with this PR.

However, 10 minutes of that was attributed to the time to build Apex. As I mentioned in a previous comment, it looks like Apex was failing to install correctly before due to the lack of the nvcc compiler in the image you were using. The new image has nvcc, and can successfully build Apex, but takes a very long time.

I just ran a test where I removed the line to install Apex, and the tests now pass in about 8:30 minutes (less than the time for CircleCI to finish). I believe this is consistent with the current test behavior, but I wanted to get your thoughts on Apex: do you feel it's worth building it in these tests and waiting the extra 10 minutes? If so, we can restore it in a follow-up and bump up the test timeout.

@tgaddair
Contributor Author

🎆

Thanks for merging!

@tgaddair tgaddair deleted the horovod branch April 22, 2020 21:42
@Borda
Member

Borda commented Apr 22, 2020

However, 10 minutes of that was attributed to the time to build Apex. As I mentioned in a previous comment, it looks like Apex was failing to install correctly before due to the lack of the nvcc compiler in the image you were using. The new image has nvcc, and can successfully build Apex, but takes a very long time.

I just ran a test where I removed the line to install Apex, and the tests now pass in about 8:30 minutes (less than the time for CircleCI to finish). I believe this is consistent with the current test behavior, but I wanted to get your thoughts on Apex: do you feel it's worth building it in these tests and waiting the extra 10 minutes? If so, we can restore it in a follow-up and bump up the test timeout.

In my opinion it is just another reason to create our own test image and use it for all CI, as we do not want to spend most of the machine time repetitively building/installing dependencies.

Maybe I am missing something, but without Apex there is no AMP support, right? So the test should fail...?

@mcarilli

mcarilli commented Jun 25, 2020

Re #1561 (comment): after talking to @tgaddair and Meet Shah on the PyTorch Slack, when using Horovod's DistributedOptimizer + native AMP, you need to ensure grads are synced across processes before the unscaling + inf-checking. In other words, you need the following pattern:

scaler.scale(loss).backward()
opt.synchronize()

# if a separate scaler.unscale_(opt) is needed,
# e.g. to allow clipping unscaled gradients,
# it should come here, after opt.synchronize()

with opt.skip_synchronize():
    scaler.step(opt)

scaler.update()

I think a similar pattern was needed with Apex.
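
Put together with torch.cuda.amp, a full training step following that pattern might look like the sketch below, assuming opt is already an hvd.DistributedOptimizer and scaler a GradScaler:

import torch

scaler = torch.cuda.amp.GradScaler()

def training_step(model, opt, batch, target):
    opt.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(batch), target)
    scaler.scale(loss).backward()

    # Average gradients across workers *before* unscaling / inf-checking.
    opt.synchronize()
    scaler.unscale_(opt)          # optional: only needed e.g. to clip unscaled grads
    with opt.skip_synchronize():  # avoid a second allreduce inside scaler.step()
        scaler.step(opt)
    scaler.update()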

@Borda
Member

Borda commented Jun 25, 2020

@mcarilli mind sending a PR? ❤️

@mcarilli

mcarilli commented Jun 26, 2020

@Borda I contacted @williamFalcon on the PyTorch Slack; he said he was already refactoring the Horovod integration and would ping me for review.

@tgaddair
Contributor Author

Hey @mcarilli, I think the AMP integration should already be in place with the Horovod backend. Are you seeing issues when trying to use it?

@mcarilli

I haven't seen any issues, just wanted to remind about the synchronize() pattern. If it's already taken care of, ignore me.

@tgaddair
Contributor Author

Thanks for clarifying and raising the issue. We should definitely double check!

@Borda Borda modified the milestones: 0.7.4, v0.7.x Apr 18, 2021