diff --git a/ddp_tutorial.md b/ddp_tutorial.md
index 4ae7a45..9de0633 100644
--- a/ddp_tutorial.md
+++ b/ddp_tutorial.md
@@ -26,7 +26,9 @@ Apex provides their own [version](https://github.com/NVIDIA/apex/tree/master/exa
 This [tutorial](http://www.telesens.co/2019/04/04/distributed-data-parallel-training-using-pytorch-on-aws/) has a good description of what's going on under the hood and how it's different from `nn.DataParallel`. However, it doesn't have code examples of how to use `nn.DistributedDataParallel`.
 
+## Outline
+This tutorial is aimed at people who are already familiar with training neural network models in Pytorch, so I won't go over those parts of the code. I'll begin by summarizing the big picture. I then show a minimum working example of training on MNIST using one GPU. I modify this example to train on multiple GPUs, possibly across multiple nodes, and explain the changes line by line. Importantly, I also explain how to run the code. As a bonus, I also demonstrate how to use Apex to do easy mixed-precision distributed training.
 
 ## The big picture
 
@@ -36,7 +38,7 @@ Multiprocessing with `DistributedDataParallel` duplicates the model across multi
 During training, each process loads its own minibatches from disk and passes them to its GPU. Each GPU does its own forward pass, and then the gradients are all-reduced across the GPUs. Gradients for each layer do not depend on previous layers, so the gradient all-reduce is calculated concurrently with the backwards pass to further alleviate the networking bottleneck. At the end of the backwards pass, every node has the averaged gradients, ensuring that the model weights stay synchronized.
 
-All this requires that the multiple processes, possibly on multiple nodes, are synchronized and communicate. Pytorch does this through its [`distributed.init_process_group`](https://pytorch.org/docs/stable/distributed.html#initialization) function. Furthermore, each process needs to know which slice of the data to work on so that the batches are non-overlapping. Pytorch provides [`nn.utils.data.DistributedSampler`](https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html) to accomplish this.
+All this requires that the multiple processes, possibly on multiple nodes, are synchronized and communicate. Pytorch does this through its [`distributed.init_process_group`](https://pytorch.org/docs/stable/distributed.html#initialization) function. This function needs to know where to find process 0 (so that all the processes can sync up) and the total number of processes to expect. Each individual process needs to know that total number of processes, its own rank within them, and which GPU to use. It's common to call the total number of processes the *world size*. Finally, each process needs to know which slice of the data to work on so that the batches are non-overlapping. Pytorch provides [`torch.utils.data.distributed.DistributedSampler`](https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html) to accomplish this.
 
 ## Minimum working examples with explanations
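As a quick illustration of the setup described in the edited paragraph above (where to find process 0, the world size, each process's rank and GPU, and the `DistributedSampler` that shards the data), here is a minimal sketch. It assumes the environment-variable initialization method of `torch.distributed`; the single-node master address, the port, and the toy dataset are placeholders for illustration only, not the tutorial's actual MNIST example.

```python
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def setup(rank, world_size, gpu):
    # Tell every process where to find process 0 so they can sync up.
    # (Assumption: a single node, so localhost and an arbitrary free port.)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '8888'
    # Each process reports its own rank and the total world size.
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=world_size, rank=rank)
    # Bind this process to its own GPU.
    torch.cuda.set_device(gpu)


def get_loader(rank, world_size, batch_size=32):
    # Toy dataset standing in for MNIST; DistributedSampler hands each
    # process a non-overlapping slice so the batches don't overlap.
    dataset = TensorDataset(torch.randn(1024, 10), torch.randint(0, 2, (1024,)))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    # shuffle must be False here; the sampler handles shuffling per epoch.
    return DataLoader(dataset, batch_size=batch_size, shuffle=False,
                      sampler=sampler)
```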