updated sync bn #2838

Merged: 13 commits, merged Aug 5, 2020

Changes from 1 commit
added ddp_spawn test
ananyahjha93 committed Aug 5, 2020
commit 984c5db0681dad31454f29a1690edc4623162276
204 changes: 0 additions & 204 deletions pl_examples/basic_examples/sync_bn.py

This file was deleted.

14 changes: 0 additions & 14 deletions pl_examples/test_examples.py
@@ -25,20 +25,6 @@ def test_gpu_template(cli_args):
run_cli()


@pytest.mark.parametrize(
'cli_args',
['--max_epochs 1 --max_steps 3 --num_nodes 1 --gpus 2 --dist_backend ddp_spawn --bn_sync']
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_bn(cli_args):
"""Test running CLI for an example with sync bn."""
from pl_examples.basic_examples.sync_bn import run_cli

cli_args = cli_args.split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
run_cli()


# @pytest.mark.parametrize('cli_args', ['--max_epochs 1 --max_steps 3 --num_nodes 1 --gpus 2'])
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# def test_multi_node_ddp(cli_args):
51 changes: 51 additions & 0 deletions tests/base/datamodules.py
@@ -1,7 +1,11 @@
import os
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader

from pytorch_lightning.core.datamodule import LightningDataModule
from tests.base.datasets import TrialMNIST
from torchvision.datasets import MNIST
from torch.utils.data.distributed import DistributedSampler


class TrialMNISTDataModule(LightningDataModule):
@@ -36,3 +40,50 @@ def val_dataloader(self):

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)


class MNISTDataModule(LightningDataModule):
def __init__(self, data_dir: str = './', batch_size=32, dist_sampler=False):
super().__init__()

self.dist_sampler = dist_sampler
self.data_dir = data_dir
self.batch_size = batch_size

self.transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])

# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self.dims = (1, 28, 28)

def prepare_data(self):
# download only
MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

def setup(self, stage=None):

# Assign train/val datasets for use in dataloaders
# TODO: need to split using random_split once updated to torch >= 1.6
if stage == 'fit' or stage is None:
self.mnist_train = MNIST(self.data_dir, train=True, transform=self.transforms)

# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transforms)

def train_dataloader(self):
dist_sampler = None
if self.dist_sampler:
dist_sampler = DistributedSampler(self.mnist_train, shuffle=False)

return DataLoader(
self.mnist_train, batch_size=self.batch_size, sampler=dist_sampler, shuffle=False
)

def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size, shuffle=False)
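The TODO in setup() above points at creating a real train/val split. A minimal sketch of what that could look like, assuming the torch >= 1.6 requirement refers to the generator argument of random_split; the 55k/5k sizes and the seed are illustrative and not part of the PR:

import torch
from torch.utils.data import random_split
from torchvision.datasets import MNIST

def setup(self, stage=None):
    # Assign train/val datasets for use in dataloaders
    if stage == 'fit' or stage is None:
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transforms)
        # deterministic split; the generator keyword is assumed to need torch >= 1.6
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000],
            generator=torch.Generator().manual_seed(42),
        )

    # Assign test dataset for use in dataloader(s)
    if stage == 'test' or stage is None:
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transforms)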
110 changes: 110 additions & 0 deletions tests/models/test_sync_batchnorm.py
@@ -0,0 +1,110 @@
import os
import math
import numpy as np
from argparse import ArgumentParser

import pytest
from collections import namedtuple
import tests.base.develop_utils as tutils

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from tests.base.datamodules import MNISTDataModule


pl.seed_everything(234)
FLOAT16_EPSILON = np.finfo(np.float16).eps


class SyncBNModule(pl.LightningModule):
def __init__(self, gpu_count=1, **kwargs):
super().__init__()

self.gpu_count = gpu_count
self.bn_targets = None
if 'bn_targets' in kwargs:
self.bn_targets = kwargs['bn_targets']

self.linear = nn.Linear(28 * 28, 10)
self.bn_layer = nn.BatchNorm1d(28 * 28)

def forward(self, x, batch_idx):
with torch.no_grad():
out_bn = self.bn_layer(x.view(x.size(0), -1))

if self.bn_targets:
bn_target = self.bn_targets[batch_idx]

# executes on both GPUs
bn_target = bn_target[self.trainer.local_rank::self.gpu_count]
bn_target = bn_target.to(out_bn.device)
assert torch.sum(torch.abs(bn_target - out_bn)) < FLOAT16_EPSILON

out = self.linear(out_bn)

return out, out_bn

def training_step(self, batch, batch_idx):
x, y = batch

y_hat, _ = self(x, batch_idx)
loss = F.cross_entropy(y_hat, y)

return pl.TrainResult(loss)

def configure_optimizers(self):
return torch.optim.Adam(self.linear.parameters(), lr=0.02)

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_batchnorm_ddp(tmpdir):
tutils.set_random_master_port()

parent_parser = ArgumentParser(add_help=False)

# define datamodule and dataloader
dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage=None)

train_dataloader = dm.train_dataloader()
model = SyncBNModule()

bn_outputs = []

# shuffle is false by default
for batch_idx, batch in enumerate(train_dataloader):
x, _ = batch

_, out_bn = model.forward(x, batch_idx)
bn_outputs.append(out_bn)

# get 3 steps
if batch_idx == 2:
break

bn_outputs = [x.cuda() for x in bn_outputs]

# reset datamodule
# batch-size = 16 because 2 GPUs in DDP
dm = MNISTDataModule(batch_size=16, dist_sampler=True)
dm.prepare_data()
dm.setup(stage=None)
Comment on lines +83 to +84
Contributor
I thought the Trainer / LightningModule will call these automatically? @nateraw


model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)

trainer = Trainer(
gpus=2,
num_nodes=1,
distributed_backend='ddp_spawn',
max_epochs=1,
max_steps=3,
sync_batchnorm=True,
num_sanity_val_steps=0,
replace_sampler_ddp=False,
)

result = trainer.fit(model, dm)
assert result == 1, "Sync batchnorm failing with DDP"
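On the Contributor comment above: when a LightningDataModule is passed to trainer.fit, the Trainer is expected to call prepare_data and setup itself; the explicit calls in this test appear to exist only because dm.train_dataloader() is iterated before fit to collect the reference batch norm outputs. Below is a rough sketch of what sync_batchnorm=True amounts to under DDP, assuming it maps onto torch.nn.SyncBatchNorm.convert_sync_batchnorm; this is an illustration reusing names from the test above, not the PR's code.

import torch

# The Trainer is assumed to perform an equivalent conversion before wrapping
# the model in DistributedDataParallel when sync_batchnorm=True is set.
model = SyncBNModule(gpu_count=2, bn_targets=bn_outputs)
sync_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

# With DistributedSampler(shuffle=False) and replace_sampler_ddp=False, rank r
# sees samples r, r+2, r+4, ... of each reference batch, which matches the
# bn_target[local_rank::gpu_count] slice asserted in SyncBNModule.forward.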