Error with ddp #4033

Closed
blizda opened this issue Oct 9, 2020 · 13 comments
Labels
bug (Something isn't working), help wanted (Open to be worked on)

Comments

@blizda

blizda commented Oct 9, 2020

🐛 Bug

When I run DDP with the sample script below, I get the following error:

Variable._execution_engine.run_backward(
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.
Exception raised from mark_variable_ready at /opt/conda/conda-bld/pytorch_1595629411241/work/torch/csrc/distributed/c10d/reducer.cpp:453 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x4d (0x7fec3631377d in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10d::Reducer::mark_variable_ready(c10d::Reducer::VariableIndex) + 0x4cd (0x7fec700ec93d in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #2: c10d::Reducer::autograd_hook(c10d::Reducer::VariableIndex) + 0xeb (0x7fec700ed17b in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #3: <unknown function> + 0xa962b6 (0x7fec700ed2b6 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #4: <unknown function> + 0xa9d366 (0x7fec700f4366 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #5: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x4dd (0x7fec6b85893d in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #6: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x451 (0x7fec6b85a401 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #7: torch::autograd::Engine::execute_with_graph_task(std::shared_ptr<torch::autograd::GraphTask> const&, std::shared_ptr<torch::autograd::Node>) + 0x25c (0x7fec6b8579fc in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #8: torch::autograd::python::PythonEngine::execute_with_graph_task(std::shared_ptr<torch::autograd::GraphTask> const&, std::shared_ptr<torch::autograd::Node>) + 0x3c (0x7fec6fb7c60c in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #9: torch::autograd::Engine::execute(std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, bool, std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&) + 0x803 (0x7fec6b856e53 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #10: torch::autograd::python::PythonEngine::execute(std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, bool, std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > const&) + 0x4e (0x7fec6fb7c3fe in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #11: THPEngine_run_backward(THPEngine*, _object*, _object*) + 0xa19 (0x7fec6fb7d0b9 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #26: torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x183 (0x7fec6fb84803 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #27: <unknown function> + 0x30d1017 (0x7fec6b85e017 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #28: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x1400 (0x7fec6b859860 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #29: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x451 (0x7fec6b85a401 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #30: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x89 (0x7fec6b852579 in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #31: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4a (0x7fec6fb7c1ba in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #32: <unknown function> + 0xc819d (0x7fec7269f19d in /home/dbliznyuk/.conda/envs/test_crash/lib/python3.8/site-packages/torch/lib/../../../.././libstdc++.so.6)
frame #33: <unknown function> + 0x76db (0x7fec8b5116db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #34: clone + 0x3f (0x7fec8b23a88f in /lib/x86_64-linux-gnu/libc.so.6)

With dp, however, everything runs normally.

Code sample

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer
from linformer_pytorch import LinformerLM
from torch.utils.data import DataLoader
    
class MyDummyDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        super().__init__()

    def __len__(self):
        return 354800

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            yield torch.randint(20000, (5120,)), torch.randint(20000, (5120,))
        else:
            yield torch.randint(20000, (5120,)), torch.randint(20000, (5120,))

class LiLinformer(pl.LightningModule):
    
    def __init__(self):
        
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = LinformerLM(
            num_tokens=self.tokenizer.vocab_size,  # Number of tokens in the LM
            input_size=5120,  # Dimension 1 of the input
            channels=128,  # Dimension 2 of the input
            dim_d=None,
            # Overwrites the inner dim of the attention heads. If None, sticks with the recommended channels // nhead, as in the "Attention is all you need" paper
            dim_k=128,  # The second dimension of the P_bar matrix from the paper
            dim_ff=128,  # Dimension in the feed forward network
            dropout_ff=0.15,  # Dropout for feed forward network
            nhead=16,  # Number of attention heads
            depth=12,  # How many times to run the model
            dropout=0.1,  # How much dropout to apply to P_bar after softmax
            activation="gelu",
            # What activation to use. Currently, only gelu and relu supported, and only on ff network.
            checkpoint_level="C2",  # What checkpoint level to use. For more information, see below.
            parameter_sharing="layerwise",  # What level of parameter sharing to use. For more information, see below.
            k_reduce_by_layer=0,
            # Going down `depth`, how much to reduce `dim_k` by, for the `E` and `F` matrices. Will have a minimum value of 1.
            full_attention=False,
            # Use full attention instead, for O(n^2) time and space complexity. Included here just for comparison
            include_ff=True,  # Whether or not to include the Feed Forward layer
            w_o_intermediate_dim=None,
            # If not None, have 2 w_o matrices, such that instead of `dim*nead,channels`, you have `dim*nhead,w_o_int`, and `w_o_int,channels`
            emb_dim=128,  # If you want the embedding dimension to be different than the channels for the Linformer
        )

    def forward(self, inputs):
        output = self.model(inputs)
        return output

    def training_step(self, inputs, mm):
        inp, labels = inputs
        loss_mx = labels != -100
        output = self.model(inp)
        output = output[loss_mx].view(-1, self.tokenizer.vocab_size)
        labels = labels[loss_mx].view(-1)
        loss = F.cross_entropy(output, labels)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        ds = MyDummyDataset()
        dataloader = DataLoader(ds, batch_size=4)
        return dataloader
    
def run_main():
    lt_model = LiLinformer()
    trainer = pl.Trainer(max_epochs=20, gpus=2, distributed_backend='ddp')
    trainer.fit(lt_model)
    
if __name__ == "__main__":
    run_main()

Environment

  • CUDA:
    - GPU:
      - TITAN RTX
      - TITAN RTX
    - available: True
    - version: 10.1
  • Packages:
    - numpy: 1.19.1
    - pyTorch_debug: False
    - pyTorch_version: 1.6.0
    - pytorch-lightning: 0.10.0
    - tqdm: 4.50.2
  • System:
    - OS: Linux
    - architecture:
      - 64bit
      - ELF
    - processor: x86_64
    - python: 3.8.5
    - version: #79-Ubuntu SMP Tue Nov 12 10:36:11 UTC 2019
blizda added the bug and help wanted labels Oct 9, 2020
@github-actions
Contributor

github-actions bot commented Oct 9, 2020

Hi! Thanks for your contribution! Great first issue!

@edenlightning
Contributor

Hey! Can you try to reproduce the error running our base model for testing?

@blizda
Author

blizda commented Oct 10, 2020

Hey! Can you try to reproduce the error running our base model for testing?

The issue does not reproduce with the testing model.

@rohitgr7
Contributor

Not 100% sure, but can you try initializing self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') in the setup method?

@blizda
Author

blizda commented Oct 10, 2020

Not 100% sure, but can you try initializing self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') in the setup method?

I just removed the tokenizer from the script, and the error still reproduces.

@rohitgr7
Contributor

rohitgr7 commented Oct 10, 2020

But you are still using it here: output = output[loss_mx].view(-1, self.tokenizer.vocab_size)

@blizda
Author

blizda commented Oct 10, 2020

But you are still using it here: output = output[loss_mx].view(-1, self.tokenizer.vocab_size)

I replaced it with a number (20000).

@rohitgr7
Contributor

Can you update the code above?

@awaelchli
Contributor

The issue does not reproduce with the testing model.

Then that's a good place to start looking for differences.
Can you share that code here, so we can spot the mistake?

@blizda
Author

blizda commented Oct 10, 2020

The issue does not reproduce with the testing model.

Then that's a good place to start looking for differences.
Can you share that code here, so we can spot the mistake?

Test code:

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# --------------------------------------------
# --------------------------------------------
# --------------------------------------------
# USE THIS MODEL TO REPRODUCE A BUG YOU REPORT
# --------------------------------------------
# --------------------------------------------
# --------------------------------------------
import os
import torch
from torch.utils.data import Dataset
from pytorch_lightning import Trainer, LightningModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
    
class RandomIteratableDataset(torch.utils.data.IterableDataset):
    def __init__(self, size, length):
        self.len = length
        self.size = size
        self.data = torch.randn(length, size)

    def __iter__(self):
        yield torch.randn(20000, (self.size))

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        """
        Testing PL Module
        Use as follows:
        - subclass
        - modify the behavior for what you want
        class TestModel(BaseTestModel):
            def training_step(...):
                # do your own thing
        or:
        model = BaseTestModel()
        model.training_epoch_end = None
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self.layer(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def run_test():
    class TestModel(BoringModel):

        def on_train_epoch_start(self) -> None:
            print('override any method to prove your bug')

    # fake data
    train_data = torch.utils.data.DataLoader(RandomIteratableDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomIteratableDataset(32, 64))
    test_data = torch.utils.data.DataLoader(RandomIteratableDataset(32, 64))

    # model
    model = TestModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        weights_summary=None,
        gpus=2, 
        distributed_backend='ddp'
    )
    trainer.fit(model, train_data, val_data)
    trainer.test(test_dataloaders=test_data)


if __name__ == '__main__':
    run_test()

Code for reproducing the error:

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from linformer_pytorch import LinformerLM
from torch.utils.data import DataLoader

class MyDummyDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        super().__init__()

    def __len__(self):
        return 354800

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            yield torch.randint(20000, (5120,)), torch.randint(20000, (5120,))
        else:
            yield torch.randint(20000, (5120,)), torch.randint(20000, (5120,))

class LiLinformer(pl.LightningModule):
    
    def __init__(self):
        
        super().__init__()
        self.model = LinformerLM(
            num_tokens=20000,  # Number of tokens in the LM
            input_size=5120,  # Dimension 1 of the input
            channels=128,  # Dimension 2 of the input
            dim_d=None,
            # Overwrites the inner dim of the attention heads. If None, sticks with the recommended channels // nhead, as in the "Attention is all you need" paper
            dim_k=128,  # The second dimension of the P_bar matrix from the paper
            dim_ff=128,  # Dimension in the feed forward network
            dropout_ff=0.15,  # Dropout for feed forward network
            nhead=16,  # Number of attention heads
            depth=12,  # How many times to run the model
            dropout=0.1,  # How much dropout to apply to P_bar after softmax
            activation="gelu",
            # What activation to use. Currently, only gelu and relu supported, and only on ff network.
            checkpoint_level="C2",  # What checkpoint level to use. For more information, see below.
            parameter_sharing="none",  # What level of parameter sharing to use. For more information, see below.
            k_reduce_by_layer=0,
            # Going down `depth`, how much to reduce `dim_k` by, for the `E` and `F` matrices. Will have a minimum value of 1.
            full_attention=False,
            # Use full attention instead, for O(n^2) time and space complexity. Included here just for comparison
            include_ff=True,  # Whether or not to include the Feed Forward layer
            w_o_intermediate_dim=None,
            # If not None, have 2 w_o matrices, such that instead of `dim*nead,channels`, you have `dim*nhead,w_o_int`, and `w_o_int,channels`
            emb_dim=128,  # If you want the embedding dimension to be different than the channels for the Linformer
        )
        
    def forward(self, inputs):
        output = self.model(inputs)
        return output

    def training_step(self, inputs, mm):
        inp, labels = inputs
        loss_mx = labels != -100
        output = self.model(inp)
        output = output[loss_mx].view(-1, 20000)
        labels = labels[loss_mx].view(-1)
        loss = F.cross_entropy(output, labels)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
    
def run_main():
    lt_model = LiLinformer()
    train_data = torch.utils.data.DataLoader(MyDummyDataset(), batch_size=4)
    val_data = torch.utils.data.DataLoader(MyDummyDataset(), batch_size=4)
    trainer = pl.Trainer(max_epochs=2, gpus=2, distributed_backend='ddp')
    trainer.fit(lt_model, train_data, val_data)
    
if __name__ == "__main__":
    run_main()

I took LinformerLM from this repo.
The model code had one .to() operation, but after removing it the error still reproduces.

Also, if I try to start training this model without Lightning, using just vanilla PyTorch, I also get the error. After this simple fix, vanilla PyTorch DDP runs without error, but Lightning still reproduces it.

@awaelchli
Contributor

awaelchli commented Oct 11, 2020

Looking at the error message at the top of what you posted:

RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the forward function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple checkpoint functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.

I found that your LiLinformer model uses torch.utils.checkpoint, which looks like the cause of the issue. I'm not sure whether something needs to be enabled to make checkpointing possible in Lightning. I would first make sure that your model does not do what is described in point 2) of the hint above.
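
To make point 2) of the hint more concrete, here is a toy sketch of the "mark ready only once" bookkeeping. It is a simplified illustration in pure Python, not PyTorch's actual reducer: only the name mark_variable_ready is borrowed from the stack trace, while ToyReducer, start_backward and simulate_backward are hypothetical names made up for this example.

```python
class ToyReducer:
    """Toy stand-in for DDP's per-iteration bookkeeping; NOT PyTorch's real reducer."""

    def __init__(self, param_names):
        self.param_names = set(param_names)
        self.ready = set()

    def start_backward(self):
        # Hypothetical helper: reset the bookkeeping before each backward pass.
        self.ready.clear()

    def mark_variable_ready(self, name):
        # Name borrowed from the stack trace; the body is a simplification.
        if name in self.ready:
            raise RuntimeError(
                f"Expected to mark a variable ready only once: {name!r}"
            )
        self.ready.add(name)


def simulate_backward(reducer, hook_firings):
    """Feed a sequence of gradient-hook firings (parameter names) to the reducer."""
    reducer.start_backward()
    for name in hook_firings:
        reducer.mark_variable_ready(name)
    # True when every registered parameter reported its gradient exactly once.
    return reducer.ready == reducer.param_names


reducer = ToyReducer(["layer.weight", "layer.bias"])

# Ordinary backward pass: each parameter's hook fires once.
print(simulate_backward(reducer, ["layer.bias", "layer.weight"]))

# Two reentrant backward passes over the same parameters: the second
# firing for "layer.bias" is rejected.
try:
    simulate_backward(
        reducer, ["layer.bias", "layer.weight", "layer.bias", "layer.weight"]
    )
except RuntimeError as err:
    print("error:", err)
```

In this simplified picture, anything that makes a parameter's gradient hook fire more than once within a single iteration trips the check.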

@tatp22

tatp22 commented Oct 13, 2020

Hey, author of the other repo here. I think that error 2 is happening:

2) Reused parameters in multiple reentrant backward passes. 
For example, if you use multiple `checkpoint` functions to wrap the same part of your model,
it would result in the same set of parameters been used by different reentrant 
backward passes multiple times, and hence marking a variable ready multiple times. 
DDP does not support such use cases yet.

Here, the model is run with both checkpointing and parameter sharing, so the same variable is marked ready more than once, which triggers this error. Setting parameter_sharing="none" should make the problem go away; alternatively, disable checkpointing by setting checkpoint_level="C0".
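
The interaction described above can be sketched on a single CPU process, without DDP. This is a minimal illustration, not code from linformer_pytorch: count_grad_hook_calls is a hypothetical helper, and it assumes a PyTorch version whose torch.utils.checkpoint.checkpoint accepts the use_reentrant keyword (PyTorch 1.6, used in this report, predates it as far as I know). A plain tensor hook stands in for the hook DDP installs on each parameter. The weight's hook should fire once when the shared layer is called directly, and once per checkpointed segment when each call is wrapped in a reentrant checkpoint.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def count_grad_hook_calls(use_checkpoint: bool) -> int:
    """Apply one shared layer twice, run backward once, and return how many
    times the gradient hook on the layer's weight fired."""
    torch.manual_seed(0)
    shared = nn.Linear(8, 8)  # the same parameters are reused by both calls
    calls = 0

    def hook(grad):
        # Stand-in for DDP's per-parameter hook: just count the firings.
        nonlocal calls
        calls += 1

    shared.weight.register_hook(hook)

    # Reentrant checkpointing needs an input that requires grad.
    h = torch.randn(4, 8, requires_grad=True)
    for _ in range(2):
        if use_checkpoint:
            # Each checkpointed segment runs its own backward pass
            # inside the outer backward pass.
            h = checkpoint(shared, h, use_reentrant=True)
        else:
            h = shared(h)
    h.sum().backward()
    return calls


print("without checkpointing:", count_grad_hook_calls(use_checkpoint=False))
print("with checkpointing:   ", count_grad_hook_calls(use_checkpoint=True))
```

Either workaround suggested above removes the repeated firing in this sketch: without sharing, each checkpointed segment owns its own parameters, and without checkpointing, the shared parameter's gradient is summed in the graph and reported once.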

@blizda
Author

blizda commented Oct 19, 2020

Thanks, all, for the help. This issue is not related to Lightning; it comes from vanilla PyTorch. Sorry for the misreport.

@blizda blizda closed this as completed Oct 19, 2020
5 participants