Skip to content

Commit

Permalink
Merge pull request #4 from PyTorchLightning/master
Browse files Browse the repository at this point in the history
update master
  • Loading branch information
nrupatunga authored Oct 1, 2020
2 parents cc16f1b + e17712e commit 6d11e64
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 161 deletions.
52 changes: 33 additions & 19 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,6 @@ def train(self):
self.__recover_child_process_weights(model, best_path, last_path)
return results

def __recover_child_process_weights(self, model, best_path, last_path):
# transfer back the best path to the trainer
if self.trainer.checkpoint_callback:
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score

# load last weights
if last_path is not None and not self.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self.trainer.model = model

def ddp_train(self, process_idx, mp_queue, model):
"""
Entry point for ddp
Expand All @@ -95,9 +82,7 @@ def ddp_train(self, process_idx, mp_queue, model):
self.trainer.progress_bar_callback.disable()

# determine which process we are and world size
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
self.set_world_ranks(process_idx)

# set warning rank
rank_zero_only.rank = self.trainer.global_rank
Expand All @@ -116,7 +101,7 @@ def ddp_train(self, process_idx, mp_queue, model):
self.trainer.call_setup_hook(model)

# on world_size=0 let everyone know training is starting
if self.trainer.is_global_zero:
if self.trainer.is_global_zero and not torch.distributed.is_initialized():
log.info('-' * 100)
log.info(f'distributed_backend={self.trainer.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
Expand All @@ -126,6 +111,9 @@ def ddp_train(self, process_idx, mp_queue, model):
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.setup_optimizers(model)
Expand All @@ -137,7 +125,7 @@ def ddp_train(self, process_idx, mp_queue, model):
model = self.trainer.precision_connector.connect(model)

# DDP spawn already spawned off each process... no need to do anything
device_ids = None
device_ids = self.get_device_ids()

# allow user to configure ddp
model = model.configure_ddp(model, device_ids)
Expand Down Expand Up @@ -174,7 +162,8 @@ def test_step(self, args):
return output

def barrier(self, name: str = None):
torch_distrib.barrier()
if torch_distrib.is_initialized():
torch_distrib.barrier()

def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)
Expand All @@ -186,6 +175,31 @@ def early_stopping_should_stop(self, pl_module):
should_stop = stop == self.trainer.world_size
return should_stop

def set_world_ranks(self, process_idx):
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
model.cpu()

def get_device_ids(self):
device_ids = None
return device_ids

def __recover_child_process_weights(self, model, best_path, last_path):
# transfer back the best path to the trainer
if self.trainer.checkpoint_callback:
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also best score

# load last weights
if last_path is not None and not self.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self.trainer.model = model

def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
# track the best model path
best_model_path = None
Expand Down
57 changes: 57 additions & 0 deletions tests/backends/test_ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
import torch
import os
from tests.backends import ddp_model
from tests.utilities.dist import call_training_script


@pytest.mark.parametrize('cli_args', [
pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp_fit_only(tmpdir, cli_args):
# call the script
std, err = call_training_script(ddp_model, cli_args, 'fit', tmpdir, timeout=120)

# load the results of the script
result_path = os.path.join(tmpdir, 'ddp.result')
result = torch.load(result_path)

# verify the file wrote the expected outputs
assert result['status'] == 'complete'


@pytest.mark.parametrize('cli_args', [
pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp_test_only(tmpdir, cli_args):
# call the script
call_training_script(ddp_model, cli_args, 'test', tmpdir)

# load the results of the script
result_path = os.path.join(tmpdir, 'ddp.result')
result = torch.load(result_path)

# verify the file wrote the expected outputs
assert result['status'] == 'complete'


# @pytest.mark.parametrize('cli_args', [
# pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
# ])
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
# # call the script
# call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)
#
# # load the results of the script
# result_path = os.path.join(tmpdir, 'ddp.result')
# result = torch.load(result_path)
#
# # verify the file wrote the expected outputs
# assert result['status'] == 'complete'
#
# model_outs = result['result']
# for out in model_outs:
# assert out['test_acc'] > 0.90
71 changes: 71 additions & 0 deletions tests/backends/test_ddp_spawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest
import torch

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from tests.base import EvalModelTemplate
from pytorch_lightning.core import memory
from pytorch_lightning.trainer import Trainer


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_early_stop_ddp_spawn(tmpdir):
"""Make sure DDP works. with early stopping"""
tutils.set_random_master_port()

trainer_options = dict(
default_root_dir=tmpdir,
early_stop_callback=True,
max_epochs=50,
limit_train_batches=10,
limit_val_batches=10,
gpus=[0, 1],
distributed_backend='ddp_spawn',
)

model = EvalModelTemplate()
tpipes.run_model_test(trainer_options, model)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp_spawn(tmpdir):
tutils.set_random_master_port()

trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=10,
limit_val_batches=10,
gpus=[0, 1],
distributed_backend='ddp_spawn',
progress_bar_refresh_rate=0
)

model = EvalModelTemplate()

tpipes.run_model_test(trainer_options, model)

# test memory helper functions
memory.get_memory_profile('min_max')


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
"""Make sure DDP works with dataloaders passed to fit()"""
tutils.set_random_master_port()

model = EvalModelTemplate()
fit_options = dict(train_dataloader=model.train_dataloader(),
val_dataloaders=model.val_dataloader())

trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
limit_train_batches=0.2,
limit_val_batches=0.2,
gpus=[0, 1],
distributed_backend='ddp_spawn'
)
result = trainer.fit(model, **fit_options)
assert result == 1, "DDP doesn't work with dataloaders passed to fit()."
44 changes: 0 additions & 44 deletions tests/models/data/ddp/train_test_variations.py

This file was deleted.

Loading

0 comments on commit 6d11e64

Please sign in to comment.