DDP fails with DDPPlugin and num_nodes>1 (with SLURM) #7429

Closed
jopo666 opened this issue May 7, 2021 · 1 comment · Fixed by #7026

jopo666 commented May 7, 2021

🐛 Bug

DDPPlugin crashes my training scripts when training models on multiple nodes (using SLURM).

When I use multiple GPUs on 1 node with the plugin -> everything works fine.
When I use multiple GPUs on multiple nodes without the plugin -> everything works fine.
When I use multiple GPUs on multiple nodes with the plugin -> crashes 😢

So when I run...

sbatch submit.sh debug.py --num_nodes 2 --num_gpus 4 --ddp_plugin

Code fails with...

----------------------------------
Total of 8 GPUs over 2 nodes.
Conda environment = DDP_Fail
pytorch-lightning 1.3.0
Running at my_secret_server.fi
Python command:
  python3 ~/debug.py --num_gpus 4 --num_nodes 2 --ddp_plugin
----------------------------------
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/8
initializing ddp: GLOBAL_RANK: 7, MEMBER: 8/8
initializing ddp: GLOBAL_RANK: 6, MEMBER: 7/8
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/8
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/8
initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/8
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/8
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.

Traceback (most recent call last):

...

ValueError: Invalid rank 5, rank should be in the interval [0, 3]
ValueError: Invalid rank 6, rank should be in the interval [0, 3]
ValueError: Invalid rank 7, rank should be in the interval [0, 3]
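
The failing ranks (4-7) are exactly the global ranks on the second node, which suggests that the manually constructed DDPPlugin computes its world size from a single node while SLURM assigns global ranks across both nodes. A minimal sketch of that arithmetic (hypothetical values, not Lightning's actual code path):

import os

# Sketch of the suspected mismatch: a manually constructed DDPPlugin defaults
# to one node's worth of processes, while SLURM assigns global ranks across
# all nodes.
num_gpus_per_node = 4
num_nodes_seen_by_plugin = 1                               # default when num_nodes isn't passed
world_size = num_nodes_seen_by_plugin * num_gpus_per_node  # 4, but should be 2 * 4 = 8

global_rank = int(os.environ.get("SLURM_PROCID", "0"))     # 0..7 across both nodes

# torch.distributed requires 0 <= rank < world_size, so ranks 4-7 on the
# second node are rejected, e.g.:
#   ValueError: Invalid rank 5, rank should be in the interval [0, 3]
assert 0 <= global_rank < world_size, "global rank outside [0, world_size)"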

To Reproduce

debug.py:

import os
import argparse

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPPlugin

class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(args):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    
    if args.ddp_plugin:
        plugin = DDPPlugin(find_unused_parameters=False)
    else:
        plugin = None

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        gpus=args.num_gpus,
        num_nodes=args.num_nodes,
        accelerator='ddp' if args.num_gpus*args.num_nodes > 1 else None,
        plugins=plugin,
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    trainer.test(model, test_dataloaders=test_data)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_nodes', type=int, default=1, metavar='',
                        help='[Default: %(default)s]')
    parser.add_argument('--num_gpus', type=int, default=1, metavar='',
                        help='[Default: %(default)s]')
    parser.add_argument('--ddp_plugin', action='store_true',
                        help='[Default: %(default)s]')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = get_args()
    run(args)

submit.sh

#!/bin/bash

#SBATCH --nodes=2
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=14
#SBATCH -o ~/logs/ddp_fail_%j.txt

echo "----------------------------------"
echo "Total of 8 GPUs over 2 nodes."
echo Conda environment = $CONDA_DEFAULT_ENV
echo $(pip list | grep lightning)

echo "Running at $(hostname)"
echo "Python command:"
CMD="    python3 ~/$@"
echo $CMD
echo "----------------------------------"
srun $CMD
echo "Done!"
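
For reference, srun with --nodes=2 and --ntasks-per-node=4 launches 8 tasks in total, and each task can inspect its placement through the standard SLURM variables (a small sketch; the example values are what a task on the second node would see):

import os

# Sketch: SLURM variables seen by one srun task with --nodes=2 --ntasks-per-node=4
# (example values for a task on the second node).
print(os.environ.get("SLURM_NNODES"))   # "2" - number of allocated nodes
print(os.environ.get("SLURM_NTASKS"))   # "8" - total tasks (2 nodes * 4 tasks per node)
print(os.environ.get("SLURM_PROCID"))   # "5" - global rank of this task (0..7)
print(os.environ.get("SLURM_LOCALID"))  # "1" - local rank on this node (0..3)
print(os.environ.get("SLURM_NODEID"))   # "1" - index of this node (0..1)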

Commands

# These are okay.
sbatch submit.sh debug.py --num_nodes 1 --num_gpus 4 
sbatch submit.sh debug.py --num_nodes 2 --num_gpus 4 
sbatch submit.sh debug.py --num_nodes 1 --num_gpus 4 --ddp_plugin

# This fails.
sbatch submit.sh debug.py --num_nodes 2 --num_gpus 4 --ddp_plugin

Expected behaviour

The code shouldn't fail: training on 2 nodes with DDPPlugin should run just like it does without the plugin. :D

Environment

  • PyTorch Version: 1.8.0
  • OS: Ubuntu
  • How you installed PyTorch: conda
  • Python version: 3.8
@jopo666 added the bug and help wanted labels on May 7, 2021
@carmocca added the distributed, environment: slurm, and priority: 1 labels on May 7, 2021
@awaelchli (Contributor) commented:

Hi! Based on your description, this is a known issue, and I have linked the PR that fixes it.
Temporary workaround: set num_nodes in the plugin:

Trainer(plugins=[DDPPlugin(num_nodes=2, ...)], ...)
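
Applied to the reproduction script above, the workaround would look roughly like this (a sketch of the plugin construction in debug.py's run(); find_unused_parameters is kept from the original script):

    if args.ddp_plugin:
        # Workaround sketch: pass num_nodes explicitly so the plugin's world
        # size matches num_nodes * num_gpus instead of a single node's GPU count.
        plugin = DDPPlugin(num_nodes=args.num_nodes, find_unused_parameters=False)
    else:
        plugin = None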
