Fix ddp tests + .test() (#2512)
* added base tests for tpu

* fix deprecation warnings

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Jirka <[email protected]>
Co-authored-by: Jeremy Jordan <[email protected]>
3 people authored Jul 7, 2020
1 parent fb85d49 commit 11069c8
Showing 26 changed files with 468 additions and 227 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/core/decorators.py
@@ -13,7 +13,7 @@ def data_loader(fn):
Warnings:
This decorator deprecated in v0.7.0 and it will be removed v0.9.0.
"""
- rank_zero_warn('`data_loader` decorator deprecated in v0.7.0. Will be removed v0.9.0', DeprecationWarning)
+ rank_zero_warn("`data_loader` decorator deprecated in v0.7.0. It will be removed in v0.9.0", DeprecationWarning)

def inner_fx(self):
return fn(self)
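
The warning-text fix above sits inside Lightning's deprecation shim for `data_loader`. As background, a generic version of that shim pattern looks roughly like this (a sketch with hypothetical names, not the library's exact code):

import warnings
from functools import wraps

def deprecated(since: str, removed_in: str):
    """Wrap a function so every call emits a DeprecationWarning."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f'`{fn.__name__}` is deprecated since {since}. '
                f'It will be removed in {removed_in}.',
                DeprecationWarning,
            )
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@deprecated(since='v0.7.0', removed_in='v0.9.0')
def data_loader(fn):
    return fn
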
4 changes: 4 additions & 0 deletions pytorch_lightning/loggers/tensorboard.py
@@ -106,6 +106,10 @@ def experiment(self) -> SummaryWriter:
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

+ @experiment.setter
+ def experiment(self, exp):
+     self._experiment = exp

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
metrics: Optional[Dict[str, Any]] = None) -> None:
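
The `@experiment.setter` added above makes the lazily created `SummaryWriter` replaceable from the outside, which is mainly useful in tests. A minimal sketch of that use, assuming only the `TensorBoardLogger` API shown in this diff:

from unittest.mock import MagicMock
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger('logs')
# Inject a stub instead of letting the property build a real writer on disk.
logger.experiment = MagicMock()
logger.experiment.add_scalar('loss', 0.5, global_step=0)  # nothing written to disk
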
7 changes: 3 additions & 4 deletions pytorch_lightning/trainer/data_loading.py
@@ -221,7 +221,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
self.num_training_batches = len(self.train_dataloader)
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
else:
- self.num_training_batches = self.limit_train_batches
+ self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)

# determine when to check validation
# if int passed in, val checks that often
@@ -313,7 +313,7 @@ def _reset_eval_dataloader(
if isinstance(limit_eval_batches, float):
num_batches = int(num_batches * limit_eval_batches)
else:
- num_batches = limit_eval_batches
+ num_batches = min(len(dataloader), limit_eval_batches)

elif limit_eval_batches not in (0.0, 1.0):
raise MisconfigurationException(
@@ -340,8 +340,7 @@ def reset_val_dataloader(self, model: LightningModule) -> None:
model: The current `LightningModule`
"""
if self.is_overridden('validation_step'):
- self.num_val_batches, self.val_dataloaders = \
-     self._reset_eval_dataloader(model, 'val')
+ self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')

def reset_test_dataloader(self, model) -> None:
"""Resets the validation dataloader and determines the number of batches.
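
The two `min(...)` changes above fix the case where an integer `limit_train_batches` / `limit_eval_batches` exceeds the dataloader length: previously the raw limit was used as the batch count, so the trainer could expect more batches than actually exist. A standalone illustration of the clamping rule (not the Trainer's code):

def clamp_num_batches(dataloader_len: int, limit) -> int:
    # float => use that fraction of the dataloader; int => absolute cap
    if isinstance(limit, float):
        return int(dataloader_len * limit)
    return min(dataloader_len, limit)

assert clamp_num_batches(100, 0.25) == 25
assert clamp_num_batches(100, 10) == 10
assert clamp_num_batches(100, 500) == 100  # pre-fix this came out as 500
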
60 changes: 44 additions & 16 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -122,6 +122,8 @@ def train_fx(trial_hparams, cluster_manager, _):
from time import sleep
import numpy as np
from os.path import abspath
+ from torch import distributed as dist
+ import queue

import torch
from pytorch_lightning import _logger as log
@@ -163,6 +165,10 @@ def train_fx(trial_hparams, cluster_manager, _):
else:
XLA_AVAILABLE = True

+ pid = os.getpid()
+ rng1 = np.random.RandomState(pid)
+ RANDOM_PORTS = rng1.randint(10000, 19999, 100)


class TrainerDDPMixin(ABC):

@@ -178,6 +184,7 @@ class TrainerDDPMixin(ABC):
use_tpu: bool
default_root_dir: str
progress_bar_callback: ...
+ checkpoint_callback: ...
num_processes: int
num_nodes: int
node_rank: int
@@ -377,17 +384,19 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
# don't make this debug... this is good UX
rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')

- def set_random_port(self):
+ def set_random_port(self, force=False):
"""
When running DDP NOT managed by SLURM, the ports might collide
"""
- try:
-     default_port = os.environ['MASTER_PORT']
- except Exception:
-     # use the process id as a seed to a generator for port only
-     pid = os.getpid()
-     rng1 = np.random.RandomState(pid)
-     default_port = rng1.randint(10000, 19999, 1)[0]
+ # pick a random port first
+ assert self.num_nodes == 1, 'random port can only be called from single node training'
+ global RANDOM_PORTS
+ default_port = RANDOM_PORTS[-1]
+ RANDOM_PORTS = RANDOM_PORTS[:-1]
+
+ # when not forced, use the user port
+ if not force:
+     default_port = os.environ.get('MASTER_PORT', default_port)

os.environ['MASTER_PORT'] = str(default_port)

@@ -446,15 +455,24 @@ def spawn_ddp_children(self, model):
sleep(delay)

local_rank = 0
- self.ddp_train(local_rank, model, is_master=True)
+ results = self.ddp_train(local_rank, q=None, model=model, is_master=True)
del os.environ['WORLD_SIZE']

- def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
+ return results
+
+ def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
"""
- Entry point into a DP thread
- :param gpu_idx:
- :param model:
- :param cluster_obj:
- :return:
+ Entry point for ddp
+ Args:
+     process_idx:
+     q:
+     model:
+     is_master:
+     proc_offset:
+ Returns:
"""
# offset the process id if requested
process_idx = process_idx + proc_offset
@@ -535,7 +553,17 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
model = model.configure_ddp(model, device_ids)

# continue training routine
- self.run_pretrain_routine(model)
+ results = self.run_pretrain_routine(model)
+
+ # clean up memory
+ torch.cuda.empty_cache()
+
+ if self.global_rank == 0 and q is not None:
+     q.put(self.checkpoint_callback.best_model_path)
+     q.put(results)
+
+ if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
+     return results

def save_spawn_weights(self, model):
"""
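
Two mechanisms drive this file's changes: a module-level, pid-seeded pool of ports (`RANDOM_PORTS`) so repeated DDP launches from one test process pick fresh, deterministic ports, and a queue (`q`) through which rank 0 hands `best_model_path` and the results back to the parent. A simplified standalone sketch of the same pattern (the path and metric values are placeholders, not the Trainer's real code):

import os
import numpy as np
import torch.multiprocessing as mp

# Seeding with the pid gives each launching process its own deterministic
# port pool, so concurrent runs on one machine rarely collide on MASTER_PORT.
RANDOM_PORTS = np.random.RandomState(os.getpid()).randint(10000, 19999, 100)

def worker(process_idx, q):
    # ... DDP setup and training would happen here ...
    if process_idx == 0 and q is not None:
        q.put('/tmp/best.ckpt')    # stand-in for checkpoint_callback.best_model_path
        q.put({'test_loss': 0.1})  # stand-in for the eval results

if __name__ == '__main__':
    os.environ.setdefault('MASTER_PORT', str(RANDOM_PORTS[-1]))
    ctx = mp.get_context('spawn')
    q = ctx.SimpleQueue()
    procs = [ctx.Process(target=worker, args=(i, q)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    best_path, results = q.get(), q.get()
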
17 changes: 11 additions & 6 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -21,6 +21,7 @@
from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only
+ from pytorch_lightning.utilities import rank_zero_warn

try:
from apex import amp
@@ -182,7 +183,8 @@ def single_gpu_train(self, model):
self.optimizers = optimizers
self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

- self.run_pretrain_routine(model)
+ results = self.run_pretrain_routine(model)
+ return results

def tpu_train(self, tpu_core_idx, model):
# call setup after the ddp process has connected
@@ -221,6 +223,7 @@ def tpu_train(self, tpu_core_idx, model):

# when training ends on these platforms dump weights to get out of the main process
if self.on_colab_kaggle:
+ rank_zero_warn('cleaning up... please do not interrupt')
self.save_spawn_weights(model)

def dp_train(self, model):
@@ -229,12 +232,12 @@ def tpu_train(self, tpu_core_idx, model):
if self.is_function_implemented('setup', model):
model.setup('fit')

- model.cuda(self.root_gpu)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

+ model.cuda(self.root_gpu)

# hack forward to do autocast for the user
model_autocast_original_forward = model.forward
if self.use_amp and NATIVE_AMP_AVALAIBLE:
@@ -264,10 +267,11 @@ def dp_train(self, model):

model = LightningDataParallel(model, device_ids=device_ids)

- self.run_pretrain_routine(model)
+ result = self.run_pretrain_routine(model)
model.forward = model_autocast_original_forward

+ return result

def horovod_train(self, model):
# call setup after the ddp process has connected
self.setup('fit')
@@ -325,10 +329,11 @@ def filter_named_parameters(model, optimizer):
# Synchronization will be performed explicitly following backward()
stack.enter_context(optimizer.skip_synchronize())

- self.run_pretrain_routine(model)
+ result = self.run_pretrain_routine(model)

# Make sure all workers have finished training before returning to the user
hvd.join()
+ return result


def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
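
Moving `model.cuda(self.root_gpu)` after `init_optimizers` in `dp_train` is safe because `nn.Module.cuda()` moves parameter data in place: the `Parameter` objects the optimizer already holds stay alive and end up on the GPU. A quick check of that assumption (a CUDA device is needed to actually exercise the move):

import torch
from torch import nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

before = [id(p) for p in model.parameters()]
if torch.cuda.is_available():
    model.cuda(0)  # in-place move; Parameter identity is preserved
after = [id(p) for p in model.parameters()]

assert before == after
# The optimizer still references the same (now on-GPU) parameters.
assert all(id(p) in before for g in opt.param_groups for p in g['params'])
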
7 changes: 5 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -325,7 +325,7 @@ def _evaluate(
if self.is_overridden('test_end', model=model):
# TODO: remove in v1.0.0
eval_results = model.test_end(outputs)
- rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed v1.0.'
+ rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
' Use `test_epoch_end` instead.', DeprecationWarning)

elif self.is_overridden('test_epoch_end', model=model):
@@ -335,7 +335,7 @@
if self.is_overridden('validation_end', model=model):
# TODO: remove in v1.0.0
eval_results = model.validation_end(outputs)
- rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed v1.0.'
+ rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
' Use `validation_epoch_end` instead.', DeprecationWarning)

elif self.is_overridden('validation_epoch_end', model=model):
@@ -391,6 +391,7 @@ def run_evaluation(self, test_mode: bool = False):
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)

# enable no returns
+ callback_metrics = {}
if eval_results is not None and len(eval_results) > 0:
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)

@@ -428,6 +429,8 @@ def run_evaluation(self, test_mode: bool = False):
else:
self.on_validation_end()

+ return callback_metrics

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
# make dataloader_idx arg in validation_step optional
args = [batch, batch_idx]
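
With `run_evaluation` now returning `callback_metrics`, the trainer has something concrete to hand back from `.test()`, which is the other half of this commit's title. A rough usage sketch against the 0.8-era API (`TinyModel` is a made-up stand-in for any LightningModule with test hooks):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def test_step(self, batch, batch_idx):
        x, y = batch
        return {'test_loss': torch.nn.functional.mse_loss(self(x), y)}

    def test_epoch_end(self, outputs):
        avg = torch.stack([o['test_loss'] for o in outputs]).mean()
        return {'log': {'test_loss': avg}}

    def test_dataloader(self):
        dataset = TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
        return DataLoader(dataset, batch_size=8)

trainer = Trainer(max_epochs=1)
trainer.test(TinyModel())  # eval metrics now propagate back out of run_evaluation
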
[Diff for the remaining 20 changed files not shown.]
