diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 8f2721201a124..b540c1b66bab3 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -13,7 +13,7 @@ def data_loader(fn): Warnings: This decorator deprecated in v0.7.0 and it will be removed v0.9.0. """ - rank_zero_warn('`data_loader` decorator deprecated in v0.7.0. Will be removed v0.9.0', DeprecationWarning) + rank_zero_warn("`data_loader` decorator deprecated in v0.7.0. It will be removed in v0.9.0", DeprecationWarning) def inner_fx(self): return fn(self) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 61da82ac7731b..29cdd49c0efbe 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -106,6 +106,10 @@ def experiment(self) -> SummaryWriter: self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs) return self._experiment + @experiment.setter + def experiment(self, exp): + self._experiment = exp + @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None) -> None: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index cfe3f744742f3..31a14bb1cd881 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -221,7 +221,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: self.num_training_batches = len(self.train_dataloader) self.num_training_batches = int(self.num_training_batches * self.limit_train_batches) else: - self.num_training_batches = self.limit_train_batches + self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches) # determine when to check validation # if int passed in, val checks that often @@ -313,7 +313,7 @@ def _reset_eval_dataloader( if isinstance(limit_eval_batches, float): num_batches = int(num_batches * limit_eval_batches) else: - num_batches = limit_eval_batches + num_batches = min(len(dataloader), limit_eval_batches) elif limit_eval_batches not in (0.0, 1.0): raise MisconfigurationException( @@ -340,8 +340,7 @@ def reset_val_dataloader(self, model: LightningModule) -> None: model: The current `LightningModule` """ if self.is_overridden('validation_step'): - self.num_val_batches, self.val_dataloaders = \ - self._reset_eval_dataloader(model, 'val') + self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val') def reset_test_dataloader(self, model) -> None: """Resets the validation dataloader and determines the number of batches. diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 3b4732ead55a8..5b8d79e47564f 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -122,6 +122,8 @@ def train_fx(trial_hparams, cluster_manager, _): from time import sleep import numpy as np from os.path import abspath +from torch import distributed as dist +import queue import torch from pytorch_lightning import _logger as log @@ -163,6 +165,10 @@ def train_fx(trial_hparams, cluster_manager, _): else: XLA_AVAILABLE = True +pid = os.getpid() +rng1 = np.random.RandomState(pid) +RANDOM_PORTS = rng1.randint(10000, 19999, 100) + class TrainerDDPMixin(ABC): @@ -178,6 +184,7 @@ class TrainerDDPMixin(ABC): use_tpu: bool default_root_dir: str progress_bar_callback: ... + checkpoint_callback: ... num_processes: int num_nodes: int node_rank: int @@ -377,17 +384,19 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids): # don't make this debug... this is good UX rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]') - def set_random_port(self): + def set_random_port(self, force=False): """ When running DDP NOT managed by SLURM, the ports might collide """ - try: - default_port = os.environ['MASTER_PORT'] - except Exception: - # use the process id as a seed to a generator for port only - pid = os.getpid() - rng1 = np.random.RandomState(pid) - default_port = rng1.randint(10000, 19999, 1)[0] + # pick a random port first + assert self.num_nodes == 1, 'random port can only be called from single node training' + global RANDOM_PORTS + default_port = RANDOM_PORTS[-1] + RANDOM_PORTS = RANDOM_PORTS[:-1] + + # when not forced, use the user port + if not force: + default_port = os.environ.get('MASTER_PORT', default_port) os.environ['MASTER_PORT'] = str(default_port) @@ -446,15 +455,24 @@ def spawn_ddp_children(self, model): sleep(delay) local_rank = 0 - self.ddp_train(local_rank, model, is_master=True) + results = self.ddp_train(local_rank, q=None, model=model, is_master=True) + del os.environ['WORLD_SIZE'] - def ddp_train(self, process_idx, model, is_master=False, proc_offset=0): + return results + + def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0): """ - Entry point into a DP thread - :param gpu_idx: - :param model: - :param cluster_obj: - :return: + Entry point for ddp + + Args: + process_idx: + q: + model: + is_master: + proc_offset: + + Returns: + """ # offset the process id if requested process_idx = process_idx + proc_offset @@ -535,7 +553,17 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0): model = model.configure_ddp(model, device_ids) # continue training routine - self.run_pretrain_routine(model) + results = self.run_pretrain_routine(model) + + # clean up memory + torch.cuda.empty_cache() + + if self.global_rank == 0 and q is not None: + q.put(self.checkpoint_callback.best_model_path) + q.put(results) + + if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn': + return results def save_spawn_weights(self, model): """ diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index a04356028229e..78bc22d21589d 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -21,6 +21,7 @@ from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities import rank_zero_warn try: from apex import amp @@ -182,7 +183,8 @@ def single_gpu_train(self, model): self.optimizers = optimizers self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers) - self.run_pretrain_routine(model) + results = self.run_pretrain_routine(model) + return results def tpu_train(self, tpu_core_idx, model): # call setup after the ddp process has connected @@ -221,6 +223,7 @@ def tpu_train(self, tpu_core_idx, model): # when training ends on these platforms dump weights to get out of the main process if self.on_colab_kaggle: + rank_zero_warn('cleaning up... please do not interrupt') self.save_spawn_weights(model) def dp_train(self, model): @@ -229,12 +232,12 @@ def dp_train(self, model): if self.is_function_implemented('setup', model): model.setup('fit') + model.cuda(self.root_gpu) + # CHOOSE OPTIMIZER # allow for lr schedulers as well self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) - model.cuda(self.root_gpu) - # hack forward to do autocast for the user model_autocast_original_forward = model.forward if self.use_amp and NATIVE_AMP_AVALAIBLE: @@ -264,10 +267,11 @@ def dp_train(self, model): model = LightningDataParallel(model, device_ids=device_ids) - self.run_pretrain_routine(model) - + result = self.run_pretrain_routine(model) model.forward = model_autocast_original_forward + return result + def horovod_train(self, model): # call setup after the ddp process has connected self.setup('fit') @@ -325,10 +329,11 @@ def filter_named_parameters(model, optimizer): # Synchronization will be performed explicitly following backward() stack.enter_context(optimizer.skip_synchronize()) - self.run_pretrain_routine(model) + result = self.run_pretrain_routine(model) # Make sure all workers have finished training before returning to the user hvd.join() + return result def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]: diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 6a0c20bfe6fe0..ee85e01d6038d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -325,7 +325,7 @@ def _evaluate( if self.is_overridden('test_end', model=model): # TODO: remove in v1.0.0 eval_results = model.test_end(outputs) - rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed v1.0.' + rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.' ' Use `test_epoch_end` instead.', DeprecationWarning) elif self.is_overridden('test_epoch_end', model=model): @@ -335,7 +335,7 @@ def _evaluate( if self.is_overridden('validation_end', model=model): # TODO: remove in v1.0.0 eval_results = model.validation_end(outputs) - rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed v1.0.' + rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.' ' Use `validation_epoch_end` instead.', DeprecationWarning) elif self.is_overridden('validation_epoch_end', model=model): @@ -391,6 +391,7 @@ def run_evaluation(self, test_mode: bool = False): eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode) # enable no returns + callback_metrics = {} if eval_results is not None and len(eval_results) > 0: _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results) @@ -428,6 +429,8 @@ def run_evaluation(self, test_mode: bool = False): else: self.on_validation_end() + return callback_metrics + def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False): # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0b279f4e531f0..1b3e053387e96 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -129,6 +129,7 @@ class Trainer( >>> trainer.fit(model, train_loader) 1 >>> trainer.test(model, train_loader) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + 1 """ DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores') @@ -894,6 +895,8 @@ def fit( # defined as part of the model, and validation can then be feed to .fit() """ + results = None + # bind logger and other properties self.copy_trainer_model_properties(model) @@ -940,43 +943,37 @@ def fit( elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ): task = int(os.environ['LOCAL_RANK']) - self.ddp_train(task, model) + self.ddp_train(process_idx=task, q=None, model=model) elif self.use_ddp: if self.is_slurm_managing_tasks: task = int(os.environ['SLURM_LOCALID']) - self.ddp_train(task, model) + self.ddp_train(process_idx=task, q=None, model=model) # torchelastic or general non_slurm ddp elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ): task = int(os.environ['LOCAL_RANK']) - self.ddp_train(task, model) + self.ddp_train(process_idx=task, q=None, model=model) elif self.distributed_backend == 'ddp_cpu': - self.set_random_port() - self.model = model - mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,)) + results = self.__run_ddp_spawn(model, nprocs=self.num_processes) elif self.distributed_backend == 'ddp_spawn': - self.set_random_port() - model.share_memory() - - # spin up peers - mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, )) + results = self.__run_ddp_spawn(model, nprocs=self.num_processes) elif self.distributed_backend == 'ddp': self.set_random_port() - self.spawn_ddp_children(model) + results = self.spawn_ddp_children(model) # 1 gpu or dp option triggers training using DP module # easier to avoid NCCL issues elif self.use_dp: - self.dp_train(model) + results = self.dp_train(model) elif self.use_horovod: - self.horovod_train(model) + results = self.horovod_train(model) elif self.single_gpu: - self.single_gpu_train(model) + results = self.single_gpu_train(model) elif self.use_tpu: # pragma: no-cover rank_zero_info(f'training on {self.tpu_cores} TPU cores') @@ -1017,7 +1014,7 @@ def fit( # allow for lr schedulers as well self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) - self.run_pretrain_routine(model) + results = self.run_pretrain_routine(model) # callbacks self.on_fit_end() @@ -1032,12 +1029,30 @@ def fit( # return 1 when finished # used for testing or when we need to know that training succeeded - return 1 + return results or 1 + + def __run_ddp_spawn(self, model, nprocs): + self.set_random_port() + + # pass in a state q + smp = mp.get_context('spawn') + q = smp.SimpleQueue() + + mp.spawn(self.ddp_train, nprocs=nprocs, args=(q, model,)) + + # restore main state with best weights + best_path = q.get() + results = q.get() + if best_path is not None and len(best_path) > 0: + self.checkpoint_callback.best_model_path = best_path + model.load_from_checkpoint(best_path) + + self.model = model + return results def can_prepare_data(self): if self.prepare_data_per_node: return self.local_rank == 0 - else: return self.node_rank == 0 and self.local_rank == 0 @@ -1108,15 +1123,24 @@ def run_pretrain_routine(self, model: LightningModule): # if cluster resets state, the model will update with the saved weights self.model = model - # restore training and model before hpc call + # restore training and model before hpc is called self.restore_weights(model) # when testing requested only run test and return if self.testing: # only load test dataloader for testing # self.reset_test_dataloader(ref_model) - self.run_evaluation(test_mode=True) - return + results = self.run_evaluation(test_mode=True) + + # remove all cuda tensors + if results is not None and isinstance(results, dict) and len(results) > 0: + for k, v in results.items(): + if isinstance(v, torch.Tensor): + results[k] = v.cpu().item() + + return results + else: + return 1 # check if we should run validation during training self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \ @@ -1210,57 +1234,65 @@ def test( trainer = Trainer() trainer.test(model, test_dataloaders=test) """ + # -------------------- + # SETUP HOOK + # -------------------- self.setup('test') model_ref = self.model if model is None else model if self.is_function_implemented('setup', model_ref): model_ref.setup('test') - self.barrier('test_setup') - + # if user requests the best checkpoint but we don't have it, error if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0: raise MisconfigurationException( 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.') - # if model is not given (None), ckpt_path is given, - # load the given checkpoint for testing + # -------------------- + # AUTO-LOAD BEST CKPT + # -------------------- + # load the best checkpoint automatically unless model is given + # in which case we use that one if model is None and ckpt_path is not None: # ckpt_path is 'best' so load the best model if ckpt_path == 'best': ckpt_path = self.checkpoint_callback.best_model_path model = self.get_model().load_from_checkpoint(ckpt_path) - self.testing = True + # ---------------------------------------------------- + # AUTO-LOAD BEST CKPT with the model trained in .fit() + # ---------------------------------------------------- + elif model is None and ckpt_path is None: + model = model_ref + # -------------------- + # LOAD DATA + # -------------------- if test_dataloaders is not None: if model: self.__attach_dataloaders(model, test_dataloaders=test_dataloaders) else: self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders) - if model is not None: - self.model = model - self.fit(model) - - # on tpu, .spawn means we don't have a trained model - # TODO: remove TPU spawn - elif self.use_tpu: # pragma: no-cover - # attempt to load weights from a spawn - path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt') - test_model = self.model - if os.path.exists(path) and self.on_colab_kaggle: - test_model = self.load_spawn_weights(self.model) - - self.fit(test_model) - else: - self.run_evaluation(test_mode=True) - + # -------------------- + # RUN TEST SET + # -------------------- + # sets up testing so we short circuit to eval + self.set_random_port(force=True) + self.testing = True + self.model = model + results = self.fit(model) self.testing = False + # -------------------- + # TEAR DOWN HOOK + # -------------------- self.teardown('test') if self.is_function_implemented('teardown'): model_ref = self.get_model() model_ref.teardown('test') + return results + def check_model_configuration(self, model: LightningModule): r""" Checks that the model is configured correctly before training or testing is started. @@ -1321,7 +1353,8 @@ def check_model_configuration(self, model: LightningModule): def barrier(self, name): if self.use_ddp or self.use_ddp2: - torch_distrib.barrier() + pass + # torch_distrib.barrier() if self.on_tpu and XLA_AVAILABLE: # wait for all processes to catch up diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5340c57581cc8..bd55881dd7f38 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -507,7 +507,8 @@ def run_training_epoch(self): def check_checkpoint_callback(self, should_check_val): # when no val loop is present or fast-dev-run still need to call checkpoints # TODO bake this logic into the checkpoint callback - if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val): + should_activate = not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val) + if should_activate: checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] [c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks] @@ -742,7 +743,6 @@ def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch): model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure, using_lbfgs=True) - # when using 16-bit else: native_amp = self.use_amp and NATIVE_AMP_AVALAIBLE @@ -889,6 +889,12 @@ def run_training_teardown(self): if self.use_ddp or self.use_ddp2: torch_distrib.destroy_process_group() + # clear mem + if self.on_gpu: + model = self.get_model() + model.cpu() + torch.cuda.empty_cache() + def training_forward(self, batch, batch_idx, opt_idx, hiddens): """ Handle forward for each training case (distributed, single gpu, etc...) diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py index caeb9e882a08a..529d64f799fcd 100644 --- a/tests/base/deterministic_model.py +++ b/tests/base/deterministic_model.py @@ -1,5 +1,6 @@ import numpy as np import torch +from torch import nn from torch.utils.data import Dataset, DataLoader from pytorch_lightning.core.lightning import LightningModule @@ -14,22 +15,25 @@ def __init__(self, weights=None): self.training_step_end_called = False self.training_epoch_end_called = False + self.l1 = nn.Linear(2, 3, bias=False) if weights is None: weights = torch.tensor([ [4, 3, 5], [10, 11, 13] ]).float() - self.l1 = torch.nn.Parameter(weights, requires_grad=True) + p = torch.nn.Parameter(weights, requires_grad=True) + self.l1.weight = p def forward(self, x): - return self.l1.mm(x.float().t()) + return self.l1(x) def step(self, batch, batch_idx): x = batch y_hat = self(x) - assert torch.all(y_hat[0, :] == 15.0) - assert torch.all(y_hat[1, :] == 42.0) + test_hat = y_hat.cpu().detach() + assert torch.all(test_hat[:, 0] == 15.0) + assert torch.all(test_hat[:, 1] == 42.0) out = y_hat.sum() assert out == (42.0 * 3) + (15.0 * 3) @@ -147,4 +151,4 @@ def __len__(self): return 12 def __getitem__(self, idx): - return np.array([0.5, 1.0, 2.0]) + return torch.tensor([0.5, 1.0, 2.0]) diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index 9ba3dd8d978c0..9424455403b83 100644 --- a/tests/base/develop_pipelines.py +++ b/tests/base/develop_pipelines.py @@ -19,8 +19,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50 # test model loading pretrained_model = load_model_from_checkpoint( trainer.logger, - trainer.checkpoint_callback.dirpath, - path_expt=trainer_options.get('default_root_dir'), + trainer.checkpoint_callback.best_model_path, ) # test new model accuracy @@ -38,6 +37,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50 def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, with_hpc: bool = True): + reset_seed() save_dir = trainer_options['default_root_dir'] @@ -46,11 +46,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi trainer_options.update(logger=logger) if 'checkpoint_callback' not in trainer_options: - # logger file to get weights - checkpoint = init_checkpoint_callback(logger) - trainer_options.update(checkpoint_callback=checkpoint) + trainer_options.update(checkpoint_callback=True) - # fit model trainer = Trainer(**trainer_options) result = trainer.fit(model) @@ -58,7 +55,7 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi assert result == 1, 'amp + ddp model failed to complete' # test model loading - pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.dirpath) + pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path) # test new model accuracy test_loaders = model.test_dataloader() diff --git a/tests/base/develop_utils.py b/tests/base/develop_utils.py index 6275c5ea67b64..995e25e090c3e 100644 --- a/tests/base/develop_utils.py +++ b/tests/base/develop_utils.py @@ -5,9 +5,10 @@ # from pl_examples import LightningTemplateModel from pytorch_lightning import seed_everything from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS from tests.base.model_template import EvalModelTemplate +import functools def assert_speed_parity_relative(pl_times, pt_times, max_diff: float = 0.1): @@ -36,12 +37,15 @@ def get_default_logger(save_dir, version=None): def get_data_path(expt_logger, path_dir=None): # some calls contain only experiment not complete logger - expt = expt_logger.experiment if hasattr(expt_logger, 'experiment') else expt_logger + # each logger has to have these attributes name, version = expt_logger.name, expt_logger.version + # only the test-tube experiment has such attribute - if hasattr(expt, 'get_data_path'): + if isinstance(expt_logger, TestTubeLogger): + expt = expt_logger.experiment if hasattr(expt_logger, 'experiment') else expt_logger return expt.get_data_path(name, version) + # the other experiments... if not path_dir: if hasattr(expt_logger, 'save_dir') and expt_logger.save_dir: @@ -49,6 +53,7 @@ def get_data_path(expt_logger, path_dir=None): else: path_dir = TEMP_PATH path_expt = os.path.join(path_dir, name, 'version_%s' % version) + # try if the new sub-folder exists, typical case for test-tube if not os.path.isdir(path_expt): path_expt = path_dir @@ -56,20 +61,8 @@ def get_data_path(expt_logger, path_dir=None): def load_model_from_checkpoint(logger, root_weights_dir, module_class=EvalModelTemplate, path_expt=None): - # load trained model - path_expt_dir = get_data_path(logger, path_dir=path_expt) - hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE) - - checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x] - weights_dir = os.path.join(root_weights_dir, checkpoints[0]) - - trained_model = module_class.load_from_checkpoint( - checkpoint_path=weights_dir, - hparams_file=hparams_path - ) - + trained_model = module_class.load_from_checkpoint(root_weights_dir) assert trained_model is not None, 'loading model failed' - return trained_model @@ -90,9 +83,32 @@ def set_random_master_port(): os.environ['MASTER_PORT'] = str(port) -def init_checkpoint_callback(logger, path_dir=None): - exp_path = get_data_path(logger, path_dir=path_dir) - ckpt_dir = os.path.join(exp_path, 'checkpoints') - os.mkdir(ckpt_dir) - checkpoint = ModelCheckpoint(ckpt_dir) +def init_checkpoint_callback(logger): + checkpoint = ModelCheckpoint(logger.save_dir) return checkpoint + + +def pl_multi_process_test(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + + from multiprocessing import Process, Queue + queue = Queue() + + def inner_f(queue, **kwargs): + try: + func(**kwargs) + queue.put(1) + except Exception as e: + import traceback + traceback.print_exc() + queue.put(-1) + + p = Process(target=inner_f, args=(queue,), kwargs=kwargs) + p.start() + p.join() + result = queue.get() + assert result == 1 + + return wrapper diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index c3e5fa3914682..2ba434af26dbb 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -130,4 +130,3 @@ def test_pickling(tmpdir): early_stopping_pickled = cloudpickle.dumps(early_stopping) early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) - diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 094bcbf1956f6..f74e815086d6f 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -156,7 +156,7 @@ def on_batch_start(self, trainer, pl_module): @pytest.mark.parametrize("logger_class", [ TensorBoardLogger, CometLogger, - #MLFlowLogger, + # MLFlowLogger, NeptuneLogger, TestTubeLogger, WandbLogger, diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 085368af105ef..6af8a90d373af 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -58,7 +58,6 @@ def save_dir(self) -> Optional[str]: """ return None - @property def name(self): return "name" diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index e6df2bbc1c691..d3362abc9ad44 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -32,17 +32,17 @@ def test_tensorboard_hparams_reload(tmpdir): # verify artifacts assert len(os.listdir(os.path.join(folder_path, 'checkpoints'))) == 1 - - # verify tb logs - event_acc = EventAccumulator(folder_path) - event_acc.Reload() - - hparams_data = b'\x12\x84\x01"\x0b\n\tdrop_prob"\x0c\n\nbatch_size"\r\n\x0bin_features"' \ - b'\x0f\n\rlearning_rate"\x10\n\x0eoptimizer_name"\x0b\n\tdata_root"\x0e\n' \ - b'\x0cout_features"\x0c\n\nhidden_dim"\x04\n\x02b1"\x04\n\x02b2' - - assert event_acc.summary_metadata['_hparams_/experiment'].plugin_data.plugin_name == 'hparams' - assert event_acc.summary_metadata['_hparams_/experiment'].plugin_data.content == hparams_data + # + # # verify tb logs + # event_acc = EventAccumulator(folder_path) + # event_acc.Reload() + # + # hparams_data = b'\x12\x84\x01"\x0b\n\tdrop_prob"\x0c\n\nbatch_size"\r\n\x0bin_features"' \ + # b'\x0f\n\rlearning_rate"\x10\n\x0eoptimizer_name"\x0b\n\tdata_root"\x0e\n' \ + # b'\x0cout_features"\x0c\n\nhidden_dim"\x04\n\x02b1"\x04\n\x02b2' + # + # assert event_acc.summary_metadata['_hparams_/experiment'].plugin_data.plugin_name == 'hparams' + # assert event_acc.summary_metadata['_hparams_/experiment'].plugin_data.content == hparams_data def test_tensorboard_automatic_versioning(tmpdir): diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index 1c187a8188332..3bd1a10411fef 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -103,7 +103,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): default_root_dir=tmpdir, max_epochs=1, gpus=[0], - distributed_backend='ddp', + distributed_backend='ddp_spawn', precision=16, checkpoint_callback=checkpoint, logger=logger, diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 8160bf8c72b44..378d7f6a2845d 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -107,7 +107,6 @@ def test_early_stopping_cpu_model(tmpdir): model.unfreeze() -@pytest.mark.spawn @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif((platform.system() == "Darwin" and diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 8401f62070564..5fc34645d34a9 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -11,57 +11,109 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate from torchtext.data import Batch, Dataset, Example, Field, LabelField - PRETEND_N_OF_GPUS = 16 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") -@pytest.mark.parametrize('gpus', [1, [0], [1]]) -def test_single_gpu_model(tmpdir, gpus): - """Make sure single GPU works (DP mode).""" +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_multi_gpu_none_backend(tmpdir): + """Make sure when using multiple GPUs the user can't use `distributed_backend = None`.""" + tutils.set_random_master_port() trainer_options = dict( default_root_dir=tmpdir, + distributed_backend=None, progress_bar_refresh_rate=0, max_epochs=1, - limit_train_batches=0.1, - limit_val_batches=0.1, - gpus=gpus + limit_train_batches=0.2, + limit_val_batches=0.2, + gpus=2 + ) + + model = EvalModelTemplate() + tpipes.run_model_test(trainer_options, model) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_multi_gpu_early_stop_ddp_spawn(tmpdir): + """Make sure DDP works. with early stopping""" + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + early_stop_callback=True, + max_epochs=50, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', ) model = EvalModelTemplate() tpipes.run_model_test(trainer_options, model) -@pytest.mark.spawn -@pytest.mark.parametrize("backend", ['dp', 'ddp', 'ddp2']) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_multi_gpu_model(tmpdir, backend): - """Make sure DDP works.""" +def test_multi_gpu_model_dp(tmpdir): tutils.set_random_master_port() trainer_options = dict( default_root_dir=tmpdir, max_epochs=1, - limit_train_batches=0.4, - limit_val_batches=0.2, + limit_train_batches=10, + limit_val_batches=10, gpus=[0, 1], - distributed_backend=backend, + distributed_backend='dp', + progress_bar_refresh_rate=0 ) model = EvalModelTemplate() - # tutils.run_model_test(trainer_options, model) - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - assert result + + tpipes.run_model_test(trainer_options, model) # test memory helper functions memory.get_memory_profile('min_max') -@pytest.mark.spawn -@pytest.mark.parametrize("backend", ['dp', 'ddp', 'ddp2']) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_multi_gpu_early_stop(tmpdir, backend): +def test_multi_gpu_model_ddp_spawn(tmpdir): + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', + progress_bar_refresh_rate=0 + ) + + model = EvalModelTemplate() + + tpipes.run_model_test(trainer_options, model) + + # test memory helper functions + memory.get_memory_profile('min_max') + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.parametrize('gpus', [1, [0], [1]]) +def test_single_gpu_model(tmpdir, gpus): + """Make sure single GPU works (DP mode).""" + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.1, + limit_val_batches=0.1, + gpus=gpus + ) + + model = EvalModelTemplate() + tpipes.run_model_test(trainer_options, model) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_multi_gpu_early_stop_dp(tmpdir): """Make sure DDP works. with early stopping""" tutils.set_random_master_port() @@ -72,17 +124,13 @@ def test_multi_gpu_early_stop(tmpdir, backend): limit_train_batches=10, limit_val_batches=10, gpus=[0, 1], - distributed_backend=backend, + distributed_backend='dp', ) model = EvalModelTemplate() - # tutils.run_model_test(trainer_options, model) - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - assert result + tpipes.run_model_test(trainer_options, model) -@pytest.mark.spawn @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_ddp_all_dataloaders_passed_to_fit(tmpdir): """Make sure DDP works with dataloaders passed to fit()""" @@ -92,10 +140,10 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir): default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=1, - limit_train_batches=0.1, - limit_val_batches=0.1, + limit_train_batches=0.2, + limit_val_batches=0.2, gpus=[0, 1], - distributed_backend='ddp' + distributed_backend='ddp_spawn' ) model = EvalModelTemplate() @@ -107,24 +155,6 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir): assert result == 1, "DDP doesn't work with dataloaders passed to fit()." -@pytest.mark.spawn -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_multi_gpu_none_backend(tmpdir): - """Make sure when using multiple GPUs the user can't use `distributed_backend = None`.""" - trainer_options = dict( - default_root_dir=tmpdir, - progress_bar_refresh_rate=0, - max_epochs=1, - limit_train_batches=0.1, - limit_val_batches=0.1, - gpus='-1' - ) - - model = EvalModelTemplate() - with pytest.warns(UserWarning): - tpipes.run_model_test(trainer_options, model) - - @pytest.fixture def mocked_device_count(monkeypatch): def device_count(): @@ -264,7 +294,7 @@ def test_parse_gpu_fail_on_non_existent_id_2(mocked_device_count): @pytest.mark.gpus_param_tests @pytest.mark.parametrize("gpus", [-1, '-1']) -def test_parse_gpu_returns_None_when_no_devices_are_available(mocked_device_count_0, gpus): +def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_count_0, gpus): with pytest.raises(MisconfigurationException): _parse_gpu_ids(gpus) diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 992150bc8dff1..64acf11f79dd1 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -88,6 +88,7 @@ def test_horovod_cpu_implicit(tmpdir): _run_horovod(trainer_options) +@pytest.mark.skipif(True, reason="fix hv") @pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") @pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support") @@ -101,7 +102,7 @@ def test_horovod_multi_gpu(tmpdir): max_epochs=1, limit_train_batches=0.4, limit_val_batches=0.2, - gpus=1, + gpus=2, deterministic=True, distributed_backend='horovod' ) @@ -128,7 +129,7 @@ def validation_step(self, batch, *args, **kwargs): return super(TestTrainingStepModel, self).validation_step(batch, *args, **kwargs) hparams = EvalModelTemplate.get_default_hparams() - model = TestTrainingStepModel(hparams) + model = TestTrainingStepModel(**hparams) trainer_options = dict( default_root_dir=str(tmpdir), diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 9eb1067322127..ff2b68e2d337b 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -92,9 +92,7 @@ def test_running_test_pretrained_model_cpu(tmpdir): # correct result and ok accuracy assert result == 1, 'training failed to complete' - pretrained_model = tutils.load_model_from_checkpoint( - logger, trainer.checkpoint_callback.dirpath, module_class=EvalModelTemplate - ) + pretrained_model = EvalModelTemplate.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) new_trainer = Trainer(**trainer_options) new_trainer.test(pretrained_model) diff --git a/tests/models/test_test_loop.py b/tests/models/test_test_loop.py new file mode 100644 index 0000000000000..141567e465b44 --- /dev/null +++ b/tests/models/test_test_loop.py @@ -0,0 +1,71 @@ +import os +import pytorch_lightning as pl +from tests.base import EvalModelTemplate +import tests.base.develop_utils as tutils +import torch +import pytest + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_single_gpu_test(tmpdir): + tutils.set_random_master_port() + + model = EvalModelTemplate() + trainer = pl.Trainer( + default_root_dir=os.getcwd(), + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0], + ) + trainer.fit(model) + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + results = trainer.test() + assert 'test_acc' in results + + results = trainer.test(model) + assert 'test_acc' in results + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_dp_test(tmpdir): + tutils.set_random_master_port() + + model = EvalModelTemplate() + trainer = pl.Trainer( + default_root_dir=os.getcwd(), + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='dp', + ) + trainer.fit(model) + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + results = trainer.test() + assert 'test_acc' in results + + results = trainer.test(model) + assert 'test_acc' in results + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_ddp_spawn_test(tmpdir): + tutils.set_random_master_port() + + model = EvalModelTemplate() + trainer = pl.Trainer( + default_root_dir=os.getcwd(), + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', + ) + trainer.fit(model) + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + results = trainer.test() + assert 'test_acc' in results + + results = trainer.test(model) + assert 'test_acc' in results diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 5fa60311a38c4..5ec2e7e7d9492 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -6,6 +6,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +import tests.base.develop_pipelines as tpipes try: import torch_xla @@ -19,6 +20,44 @@ TPU_AVAILABLE = True +@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine") +@pytest.mark.parametrize("tpu_cores", [1, [1], 8]) +def test_base_tpu_model(tmpdir, tpu_cores): + """Make sure model trains on TPU.""" + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=1, + tpu_cores=tpu_cores, + limit_train_batches=0.4, + limit_val_batches=0.4 + ) + + model = EvalModelTemplate() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + + +@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine") +@pytest.mark.parametrize("tpu_cores", [1, [1], 8]) +def test_base_tpu_16bit_model(tmpdir, tpu_cores): + """Make sure model trains on TPU.""" + trainer_options = dict( + default_root_dir=tmpdir, + precision=16, + progress_bar_refresh_rate=0, + max_epochs=1, + tpu_cores=tpu_cores, + limit_train_batches=0.4, + limit_val_batches=0.4 + ) + + model = EvalModelTemplate() + + tpipes.run_model_test(trainer_options, model, on_gpu=False) + + assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables" + + @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine") @pytest.mark.parametrize(['tpu_cores', 'expected_device'], [ pytest.param([1], 'xla:1'), @@ -60,7 +99,6 @@ def test_single_tpu_core_model(tmpdir, tpu_cores, expected_device): assert torch_xla._XLAC._xla_get_default_device() == expected_device -@pytest.mark.spawn @pytest.mark.parametrize("tpu_cores", [1, 8]) @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine") def test_multi_core_tpu_model(tmpdir, tpu_cores): @@ -77,7 +115,6 @@ def test_multi_core_tpu_model(tmpdir, tpu_cores): assert trainer.tpu_id is None -@pytest.mark.spawn @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine") def test_dataloaders_passed_to_fit(tmpdir): """Test if dataloaders passed to trainer works on TPU""" @@ -97,24 +134,6 @@ def test_dataloaders_passed_to_fit(tmpdir): assert result, "TPU doesn't work with dataloaders passed to fit()." -@pytest.mark.spawn -@pytest.mark.parametrize("tpu_cores", [1, 8, [1]]) -@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine") -def test_mixed_precision_with_tpu(tmpdir, tpu_cores): - """Test if FP16 TPU core training works""" - model = EvalModelTemplate() - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - train_percent_check=0.4, - val_percent_check=0.2, - tpu_cores=tpu_cores, - precision=16 - ) - trainer.fit(model) - assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables" - - @pytest.mark.parametrize(['tpu_cores', 'expected_tpu_id'], [ pytest.param(1, None), pytest.param(8, None), diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index e119c2bff81bb..d5665a12acfe3 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -17,75 +17,75 @@ def _soft_unimport_module(str_module): def test_tbd_remove_in_v0_10_0_trainer(): rnd_val = random.random() - with pytest.deprecated_call(match='v0.10.0'): + with pytest.deprecated_call(match='will be removed in v0.10.0'): trainer = Trainer(overfit_pct=rnd_val) assert trainer.overfit_batches == rnd_val - with pytest.deprecated_call(match='v0.10.0'): + with pytest.deprecated_call(match='will be removed in v0.10.0'): assert trainer.overfit_pct == rnd_val rnd_val = random.random() - with pytest.deprecated_call(match='v0.10.0'): + with pytest.deprecated_call(match='will be removed in v0.10.0'): trainer = Trainer(train_percent_check=rnd_val) assert trainer.limit_train_batches == rnd_val with pytest.deprecated_call(match='v0.10.0'): assert trainer.train_percent_check == rnd_val rnd_val = random.random() - with pytest.deprecated_call(match='v0.10.0'): + with pytest.deprecated_call(match='will be removed in v0.10.0'): trainer = Trainer(val_percent_check=rnd_val) assert trainer.limit_val_batches == rnd_val - with pytest.deprecated_call(match='v0.10.0'): + with pytest.deprecated_call(match='will be removed in v0.10.0'): assert trainer.val_percent_check == rnd_val rnd_val = random.random() - with pytest.deprecated_call(match='v0.10.0'): + with pytest.deprecated_call(match='will be removed in v0.10.0'): trainer = Trainer(test_percent_check=rnd_val) assert trainer.limit_test_batches == rnd_val - with pytest.deprecated_call(match='v0.10.0'): + with pytest.deprecated_call(match='will be removed in v0.10.0'): assert trainer.test_percent_check == rnd_val trainer = Trainer() - with pytest.deprecated_call(match='v0.10.0'): + with pytest.deprecated_call(match='will be removed in v0.10.0'): trainer.proc_rank = 0 - with pytest.deprecated_call(match='v0.10.0'): + with pytest.deprecated_call(match='will be removed in v0.10.0'): assert trainer.proc_rank == trainer.global_rank def test_tbd_remove_in_v0_9_0_trainer(): # test show_progress_bar set by progress_bar_refresh_rate - with pytest.deprecated_call(match='v0.9.0'): + with pytest.deprecated_call(match='will be removed in v0.9.0'): trainer = Trainer(progress_bar_refresh_rate=0, show_progress_bar=True) assert not getattr(trainer, 'show_progress_bar') - with pytest.deprecated_call(match='v0.9.0'): + with pytest.deprecated_call(match='will be removed in v0.9.0'): trainer = Trainer(progress_bar_refresh_rate=50, show_progress_bar=False) assert getattr(trainer, 'show_progress_bar') - with pytest.deprecated_call(match='v0.9.0'): + with pytest.deprecated_call(match='will be removed in v0.9.0'): trainer = Trainer(num_tpu_cores=8) assert trainer.tpu_cores == 8 def test_tbd_remove_in_v0_9_0_module_imports(): _soft_unimport_module("pytorch_lightning.core.decorators") - with pytest.deprecated_call(match='v0.9.0'): + with pytest.deprecated_call(match='will be removed in v0.9.0'): from pytorch_lightning.core.decorators import data_loader # noqa: F811 data_loader(print) _soft_unimport_module("pytorch_lightning.logging.comet") - with pytest.deprecated_call(match='v0.9.0'): + with pytest.deprecated_call(match='will be removed in v0.9.0'): from pytorch_lightning.logging.comet import CometLogger # noqa: F402 _soft_unimport_module("pytorch_lightning.logging.mlflow") - with pytest.deprecated_call(match='v0.9.0'): + with pytest.deprecated_call(match='will be removed in v0.9.0'): from pytorch_lightning.logging.mlflow import MLFlowLogger # noqa: F402 _soft_unimport_module("pytorch_lightning.logging.neptune") - with pytest.deprecated_call(match='v0.9.0'): + with pytest.deprecated_call(match='will be removed in v0.9.0'): from pytorch_lightning.logging.neptune import NeptuneLogger # noqa: F402 _soft_unimport_module("pytorch_lightning.logging.test_tube") - with pytest.deprecated_call(match='v0.9.0'): + with pytest.deprecated_call(match='will be removed in v0.9.0'): from pytorch_lightning.logging.test_tube import TestTubeLogger # noqa: F402 _soft_unimport_module("pytorch_lightning.logging.wandb") - with pytest.deprecated_call(match='v0.9.0'): + with pytest.deprecated_call(match='will be removed in v0.9.0'): from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402 @@ -136,7 +136,7 @@ def test_tbd_remove_in_v1_0_0_model_hooks(): trainer.test(model) assert trainer.callback_metrics == {'test_loss': torch.tensor(0.6)} - with pytest.deprecated_call(match='v1.0'): + with pytest.deprecated_call(match='will be removed in v1.0'): trainer = Trainer(logger=False) # TODO: why `dataloder` is required if it is not used result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1) @@ -144,12 +144,12 @@ def test_tbd_remove_in_v1_0_0_model_hooks(): model = ModelVer0_7() - with pytest.deprecated_call(match='v1.0'): + with pytest.deprecated_call(match='will be removed in v1.0'): trainer = Trainer(logger=False) trainer.test(model) assert trainer.callback_metrics == {'test_loss': torch.tensor(0.7)} - with pytest.deprecated_call(match='v1.0'): + with pytest.deprecated_call(match='will be removed in v1.0'): trainer = Trainer(logger=False) # TODO: why `dataloder` is required if it is not used result = trainer._evaluate(model, dataloaders=[[None]], max_batches=1) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 4d7b3de8e8286..e76ef0e556352 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -326,7 +326,12 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v assert trainer.num_training_batches == limit_train_batches assert trainer.num_val_batches == [limit_val_batches] * len(trainer.val_dataloaders) trainer.test(ckpt_path=None) - assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders) + + # when the limit is greater than the number of test batches it should be the num in loaders + if limit_test_batches > 1e10: + assert trainer.num_test_batches == [len(x) for x in model.test_dataloader()] + else: + assert trainer.num_test_batches == [limit_test_batches] * len(trainer.test_dataloaders) @pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) @@ -534,7 +539,7 @@ class CustomDataLoader(torch.utils.data.DataLoader): def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, - worker_init_fn=None, dummy_kwarg=None): + worker_init_fn=None, dummy_kwarg=None, **kwargs): super().__init__(dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last, timeout, worker_init_fn) @@ -544,7 +549,7 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, trainer = Trainer( gpus=[0, 1], num_nodes=1, - distributed_backend='ddp', + distributed_backend='ddp_spawn', ) class CustomDummyObj: @@ -553,7 +558,8 @@ class CustomDummyObj: result = trainer.auto_add_sampler(CustomDummyObj(), train=True) assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader" - result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))), train=True) + dataset = list(range(1000)) + result = trainer.auto_add_sampler(CustomDataLoader(dataset), train=True) assert isinstance(result, torch.utils.data.DataLoader) assert isinstance(result, CustomDataLoader) assert hasattr(result, 'dummy_kwarg') diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py index 05627f8e1e3b1..6091f486257c7 100644 --- a/tests/trainer/test_trainer_steps.py +++ b/tests/trainer/test_trainer_steps.py @@ -4,7 +4,6 @@ import torch -@pytest.mark.spawn @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_training_step_dict(tmpdir): """ @@ -17,8 +16,6 @@ def test_training_step_dict(tmpdir): trainer = Trainer( default_root_dir=tmpdir, fast_dev_run=True, - precision=16, - gpus=1, weights_summary=None, ) trainer.fit(model) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 48a8f9011811c..276d35f56ade6 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -8,6 +8,33 @@ from tests.base import EvalModelTemplate +def test_num_training_batches(tmpdir): + """ + Tests that the correct number of batches are allocated + """ + # when we have fewer batches in the dataloader we should use those instead of the limit + model = EvalModelTemplate() + trainer = Trainer(limit_val_batches=100, limit_train_batches=100, max_epochs=1) + trainer.fit(model) + + assert len(model.train_dataloader()) == 10 + assert len(model.val_dataloader()) == 10 + assert isinstance(trainer.num_val_batches, list) + assert trainer.num_val_batches[0] == 10 + assert trainer.num_training_batches == 10 + + # when we have more batches in the dataloader we should limit them + model = EvalModelTemplate() + trainer = Trainer(limit_val_batches=7, limit_train_batches=7, max_epochs=1) + trainer.fit(model) + + assert len(model.train_dataloader()) == 10 + assert len(model.val_dataloader()) == 10 + assert isinstance(trainer.num_val_batches, list) + assert trainer.num_val_batches[0] == 7 + assert trainer.num_training_batches == 7 + + def test_overfit_batch_limits(tmpdir): # ------------------------------------------------------ # Make sure shuffle is correct across loaders initially