error=101 : invalid device ordinal #3791

Closed
rotabulo opened this issue Oct 2, 2020 · 11 comments · Fixed by #4297
Labels: bug (Something isn't working), distributed (Generic distributed-related topic), help wanted (Open to be worked on), priority: 0 (High priority task)

@rotabulo

rotabulo commented Oct 2, 2020

🐛 Bug

When the first entry of CUDA_VISIBLE_DEVICES is greater than 0 and the ddp backend is used, CUDA raises an invalid device ordinal error.
After digging into the library's code, I found the source of the issue in this function:
https://github.com/PyTorchLightning/pytorch-lightning/blob/440f837f6d1b5fc44e6f04475fd2af20e2ed370d/pytorch_lightning/accelerators/ddp_backend.py#L151
which I copy and paste here:

def model_to_device(self, model, process_idx, is_master):
    gpu_idx = process_idx  # line 1

    # when using ddp, the master process (proc 0) continues running as the main one
    # this means that the local rank will always be 0
    # (even if cuda visible devices has other visible gpus)
    # this means that the master process needs to pull the 0th visible index as the device number
    if is_master:  # line 2
        available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')  # line 3
        gpu_idx = int(available_gpus[self.trainer.local_rank])  # line 4

    gpu_idx = int(os.environ.get('PL_DDP_PID', gpu_idx))  # line 5

    self.trainer.root_gpu = gpu_idx  # line 6
    torch.cuda.set_device(self.trainer.root_gpu)  # line 7
    model.cuda(self.trainer.root_gpu)  # line 8

Assume is_master=True, self.trainer.local_rank=0, and CUDA_VISIBLE_DEVICES=4,5,6,7 in the environment.
Then in line 4 gpu_idx becomes 4. In line 5 gpu_idx remains unchanged because PL_DDP_PID is not defined.
Finally, line 7 raises the error: with four visible devices, valid device ordinals lie in the range [0, 4), but we try to set device 4.
The problem is that line 4 stores a physical GPU ID (an entry of CUDA_VISIBLE_DEVICES) in gpu_idx, whereas torch.cuda.set_device in line 7 expects an index relative to the visible devices.
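The index arithmetic above can be reproduced without any GPU. The sketch below is a hypothetical, pure-Python re-implementation of lines 1 to 5 of the excerpt (the helpers resolve_gpu_idx and is_valid_ordinal are illustrative names, not Lightning's API), plus a range check that mirrors what set_device enforces. It assumes CUDA_VISIBLE_DEVICES holds integer IDs.

```python
from typing import Mapping


def resolve_gpu_idx(process_idx: int, is_master: bool, local_rank: int,
                    env: Mapping[str, str]) -> int:
    """Hypothetical re-implementation of lines 1-5 of model_to_device."""
    gpu_idx = process_idx                                         # line 1
    if is_master:                                                 # line 2
        available_gpus = env["CUDA_VISIBLE_DEVICES"].split(",")   # line 3
        gpu_idx = int(available_gpus[local_rank])                 # line 4: a physical GPU ID
    return int(env.get("PL_DDP_PID", gpu_idx))                    # line 5


def is_valid_ordinal(ordinal: int, env: Mapping[str, str]) -> bool:
    """set_device accepts ordinals in [0, number of visible devices)."""
    num_visible = len(env["CUDA_VISIBLE_DEVICES"].split(","))
    return 0 <= ordinal < num_visible


env = {"CUDA_VISIBLE_DEVICES": "4,5,6,7"}
idx = resolve_gpu_idx(process_idx=0, is_master=True, local_rank=0, env=env)
print(idx, is_valid_ordinal(idx, env))  # prints: 4 False
```

The physical ID 4 is returned unchanged and fails the range check, which is exactly the invalid device ordinal reported at line 7.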

@rotabulo rotabulo added the bug and help wanted labels Oct 2, 2020
@rotabulo
Author

rotabulo commented Oct 2, 2020

I managed to fix it by adding the following line before https://github.com/PyTorchLightning/pytorch-lightning/blob/440f837f6d1b5fc44e6f04475fd2af20e2ed370d/pytorch_lightning/accelerators/ddp_backend.py#L109:

os.environ['PL_DDP_PID'] = str(self.trainer.data_parallel_device_ids[0])
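The workaround relies on line 5: PL_DDP_PID, when set, overrides whatever line 4 computed. The sketch below models that with a hypothetical stand-in for lines 1 to 5 (resolve_gpu_idx and apply_workaround are illustrative names). It assumes the GPUs were requested as a count, so that data_parallel_device_ids holds relative ordinals such as [0, 1, 2, 3]; that assumption is not confirmed in the thread.

```python
from typing import Dict, List, Mapping


def resolve_gpu_idx(process_idx: int, is_master: bool, local_rank: int,
                    env: Mapping[str, str]) -> int:
    """Hypothetical stand-in for lines 1-5 of model_to_device."""
    gpu_idx = process_idx
    if is_master:
        gpu_idx = int(env["CUDA_VISIBLE_DEVICES"].split(",")[local_rank])
    # PL_DDP_PID, if present, overrides the value computed above
    return int(env.get("PL_DDP_PID", gpu_idx))


def apply_workaround(env: Mapping[str, str], data_parallel_device_ids: List[int]) -> Dict[str, str]:
    """Mimic the reported fix on a copy of the environment."""
    patched = dict(env)
    patched["PL_DDP_PID"] = str(data_parallel_device_ids[0])
    return patched


env = {"CUDA_VISIBLE_DEVICES": "4,5,6,7"}
# Assumption: with a GPU count, the device IDs are relative ordinals [0, 1, 2, 3]
patched = apply_workaround(env, [0, 1, 2, 3])
print(resolve_gpu_idx(0, True, 0, env), resolve_gpu_idx(0, True, 0, patched))  # prints: 4 0
```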

@edenlightning
Contributor

hey! which version of lightning are you using?

@rotabulo
Author

rotabulo commented Oct 2, 2020

hey! which version of lightning are you using?

@edenlightning I took the most recent one from master to be sure it was not already fixed.

@edenlightning edenlightning added the distributed label Oct 2, 2020
@edenlightning edenlightning added this to the 0.9.x milestone Oct 2, 2020
@edenlightning edenlightning modified the milestones: 0.9.x, 1.0 Oct 4, 2020
@williamFalcon
Contributor

williamFalcon commented Oct 4, 2020

this code is old :)

Try the version on master.

As you see, this has been fixed.

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/accelerators/ddp_backend.py#L173-L176

@williamFalcon
Contributor

Tested with this command and it worked:

CUDA_VISIBLE_DEVICES='2,3' python pl_examples/basic_examples/autoencoder.py --distributed_backend 'ddp' --gpus '0, 1'
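In this command, --gpus '0, 1' works because CUDA renumbers the visible devices inside the process: ordinal 0 refers to the first entry of CUDA_VISIBLE_DEVICES (physical GPU 2) and ordinal 1 to the second (physical GPU 3). A small sketch of that mapping, assuming CUDA_VISIBLE_DEVICES contains integer IDs rather than GPU UUIDs; the helper names are hypothetical.

```python
from typing import List


def visible_devices(cuda_visible_devices: str) -> List[int]:
    """Parse a CUDA_VISIBLE_DEVICES string such as '2,3' into physical GPU IDs."""
    return [int(tok) for tok in cuda_visible_devices.split(",") if tok.strip()]


def ordinal_to_physical(ordinal: int, cuda_visible_devices: str) -> int:
    """Inside the process, ordinal k refers to the k-th entry of CUDA_VISIBLE_DEVICES."""
    devices = visible_devices(cuda_visible_devices)
    if not 0 <= ordinal < len(devices):
        raise ValueError(f"invalid device ordinal {ordinal} for visible devices {devices}")
    return devices[ordinal]


# --gpus '0, 1' combined with CUDA_VISIBLE_DEVICES='2,3'
print([ordinal_to_physical(k, "2,3") for k in (0, 1)])  # prints: [2, 3]
```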

@catalys1
Contributor

catalys1 commented Oct 12, 2020

I'm still having the same problem on 1.0.0rc5. When I run something like python train.py --gpus 0,2 --distributed_backend ddp, the script crashes (and then hangs, so I have to kill -9 it) with an error like this:

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,2]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,2]
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/2
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/2
THCudaCheck FAIL file=/pytorch/torch/csrc/cuda/Module.cpp line=59 error=101 : invalid device ordinal
Traceback (most recent call last):
  ...
  File "/home/catalys1/pylt/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 185, in model_to_device
    torch.cuda.set_device(self.trainer.root_gpu)
  File "/home/catalys1/pylt/lib/python3.8/site-packages/torch/cuda/__init__.py", line 281, in set_device
    torch._C._cuda_setDevice(device)
RuntimeError: cuda runtime error (101) : invalid device ordinal at /pytorch/torch/csrc/cuda/Module.cpp:59

However, if I run it the way @williamFalcon did, with

CUDA_VISIBLE_DEVICES='0,2' python train.py --gpus 0,1 --distributed_backend ddp

then it runs just fine. But I thought part of the purpose of the --gpus flag was to make it unnecessary to set CUDA_VISIBLE_DEVICES manually?
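The log above hints at the mismatch: both processes report CUDA_VISIBLE_DEVICES: [0,2], yet the second one initializes with GLOBAL_RANK: 2 and MEMBER: 3/2, as if the GPU ID 2 had been used where the position 1 was expected. That reading is an inference from the log, not something confirmed in the thread. The hypothetical helper below shows the bookkeeping one would expect when every process sees the same restricted device list: ranks and set_device ordinals are positions in the list, never the raw IDs.

```python
from typing import Dict, List


def plan_devices(gpu_ids: List[int]) -> List[Dict[str, int]]:
    """Hypothetical helper: per-process rank and device ordinal for DDP.

    Assumes every process sees CUDA_VISIBLE_DEVICES set to the same list as
    gpu_ids, so the ordinal passed to set_device is the position in that list.
    """
    plan = []
    for position, physical_id in enumerate(gpu_ids):
        plan.append({
            "local_rank": position,      # 0 .. len(gpu_ids) - 1
            "ordinal": position,         # valid argument for set_device
            "physical_id": physical_id,  # entry of CUDA_VISIBLE_DEVICES
        })
    return plan


for entry in plan_devices([0, 2]):
    print(entry)
# prints:
# {'local_rank': 0, 'ordinal': 0, 'physical_id': 0}
# {'local_rank': 1, 'ordinal': 1, 'physical_id': 2}
```

Using physical_id 2 as an ordinal here would be out of range for two visible devices, which matches the traceback.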

@williamFalcon
Contributor

williamFalcon commented Oct 12, 2020

It's not required to specify visible devices... I only did it because that was your example.

gpus should be a string when called via the CLI:

"0, 2"

@catalys1
Contributor

Right. I was just pointing out that when I explicitly set CUDA_VISIBLE_DEVICES it works, but if I try to select a subset of GPUs without setting CUDA_VISIBLE_DEVICES, I get an error, both with and without quotes around the GPU list.

@williamFalcon
Contributor

Got it. OK, I think something is getting lost in translation, haha. Could you please:

  1. replicate using BoringModel.
  2. Write out the exact command you are using that DOES work.
  3. Write out the exact command you are using that does NOT work.

Thank you! Very excited to track this down :)

@williamFalcon williamFalcon reopened this Oct 13, 2020
@catalys1
Contributor

It can't be reproduced on Colab, since it's an issue with ddp. Here's a BoringModel script that reproduces it on my machine:

# -*- coding: utf-8 -*-
"""The BoringModel.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1HvWVVTK8j2Nj52qU4Q4YCyzOm0_aLQF3


# The Boring Model
Replicate a bug you experience, using this model.

[Remember! we're always available for support on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)

---
## Setup env
"""

lightning_version = '1.0.0' #@param ["1.0.0", "0.10.0", "0.9.0", "0.8.5", "0.8.0"]


"""---
## Deps
"""

import os

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
tmpdir = os.getcwd()

"""---
## Data
Random data is best for debugging. If you need special tensor shapes, batch compositions or dataloaders, modify as needed.
"""


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

num_samples = 10000

train = RandomDataset(32, num_samples)
train = DataLoader(train, batch_size=32)

val = RandomDataset(32, num_samples)
val = DataLoader(val, batch_size=32)

test = RandomDataset(32, num_samples)
test = DataLoader(test, batch_size=32)

"""---

## Model
Modify this as needed to replicate your bug
"""

import torch
from pytorch_lightning import LightningModule
from torch.utils.data import Dataset

class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

"""---
## Define the test
NOTE: in colab, set progress_bar_refresh_rate high or the screen will freeze because of the rapid tqdm update speed.
"""

def test_x(tmpdir):
    import argparse
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # init model
    model = BoringModel()

    # Initialize a trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        max_epochs=1, 
    )

    # Train the model ⚡
    trainer.fit(model, train, val)

"""---
## Run Test
"""

test_x(tmpdir)

This command runs successfully:

CUDA_VISIBLE_DEVICES="1,2" python the_boringmodel.py --gpus=2 --distributed_backend=ddp

This one does not:

python the_boringmodel.py --gpus="1,2" --distributed_backend=ddp

It gives the error I showed in my last post.
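For comparison, one common pattern outside Lightning (not what the library does in this thread) sidesteps the physical-to-relative translation entirely: each worker process is launched with CUDA_VISIBLE_DEVICES set to a single physical GPU, so it always uses ordinal 0. The sketch only builds the per-worker environments; the helper name is hypothetical, and the LOCAL_RANK variable name is an assumption that follows PyTorch's launcher convention.

```python
import os
from typing import Dict, List, Mapping


def per_process_envs(gpu_ids: List[int], base_env: Mapping[str, str]) -> List[Dict[str, str]]:
    """Build one environment per worker that exposes exactly one physical GPU.

    Each worker would then call torch.cuda.set_device(0). Pattern sketch only.
    """
    envs = []
    for local_rank, physical_id in enumerate(gpu_ids):
        env = dict(base_env)
        env["CUDA_VISIBLE_DEVICES"] = str(physical_id)
        env["LOCAL_RANK"] = str(local_rank)
        envs.append(env)
    return envs


# Equivalent of --gpus="1,2": physical GPUs 1 and 2
envs = per_process_envs([1, 2], os.environ)
print([(e["LOCAL_RANK"], e["CUDA_VISIBLE_DEVICES"]) for e in envs])
# prints: [('0', '1'), ('1', '2')]
```

Each dictionary could be passed as the env argument of subprocess.Popen. The trade-off is that every worker only sees its own GPU.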

@edenlightning edenlightning modified the milestones: 1.0, 1.0.3 Oct 19, 2020
@edenlightning edenlightning assigned awaelchli and unassigned awaelchli Oct 19, 2020
@awaelchli awaelchli self-assigned this Oct 20, 2020
@edenlightning edenlightning added the priority: 0 label Oct 20, 2020
@awaelchli
Contributor

awaelchli commented Oct 21, 2020

I'm working on this, but there's no big breakthrough yet. I'm facing some difficulties because several global and environment variables determine the GPU selection, which makes ddp quite difficult to debug.
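When several environment variables feed into the GPU choice, a per-process snapshot logged at startup can help. The sketch below is a hypothetical debugging helper; the default variable list only contains names that appear in this thread (CUDA_VISIBLE_DEVICES, LOCAL_RANK, PL_DDP_PID) and is meant to be extended.

```python
import os
from typing import Dict, Iterable, Mapping, Optional

# Variables that appear in this thread; extend the tuple as needed.
GPU_ENV_VARS = ("CUDA_VISIBLE_DEVICES", "LOCAL_RANK", "PL_DDP_PID")


def gpu_env_snapshot(env: Mapping[str, str] = os.environ,
                     names: Iterable[str] = GPU_ENV_VARS) -> Dict[str, Optional[str]]:
    """Return the GPU-selection variables visible to this process (None if unset)."""
    return {name: env.get(name) for name in names}


def format_snapshot(snapshot: Dict[str, Optional[str]]) -> str:
    """One log line per process, prefixed with the PID so ddp workers can be told apart."""
    parts = [f"{k}={v if v is not None else '<unset>'}" for k, v in snapshot.items()]
    return f"[pid {os.getpid()}] " + " ".join(parts)


snap = gpu_env_snapshot({"CUDA_VISIBLE_DEVICES": "0,2", "LOCAL_RANK": "1"})
print(format_snapshot(snap))
```

Calling format_snapshot(gpu_env_snapshot()) right before set_device in each process would show which value each worker actually saw.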

5 participants