diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 5116df9fc0786..9b451e8ef88d6 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -3,6 +3,7 @@
 import warnings
 import logging as log
 from typing import Union, Optional, List, Dict, Tuple, Iterable
+from argparse import ArgumentParser
 
 import torch
 import torch.distributed as dist
@@ -38,19 +39,19 @@
 
 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
 
 try:
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
-
-    XLA_AVAILABLE = True
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
 
 
 class Trainer(TrainerIOMixin,
@@ -71,7 +72,7 @@ def __init__(
             self,
             logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
             checkpoint_callback: Union[ModelCheckpoint, bool] = True,
-            early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,
+            early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
             callbacks: List[Callback] = [],
             default_save_path: Optional[str] = None,
             gradient_clip_val: float = 0,
@@ -98,7 +99,7 @@ def __init__(
             train_percent_check: float = 1.0,
             val_percent_check: float = 1.0,
             test_percent_check: float = 1.0,
-            val_check_interval: Union[float] = 1.0,
+            val_check_interval: float = 1.0,
             log_save_interval: int = 100,
             row_log_interval: int = 10,
             add_row_log_interval=None,  # backward compatible, todo: remove in v0.8.0
@@ -116,6 +117,7 @@ def __init__(
             profiler: Optional[BaseProfiler] = None,
             benchmark: bool = False,
             reload_dataloaders_every_epoch: bool = False,
+            **kwargs
     ):
         r"""
@@ -153,8 +155,9 @@ def __init__(
 
             trainer = Trainer(checkpoint_callback=checkpoint_callback)
 
-        early_stop_callback: Callback for early stopping. If
-            set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
+        early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`):
+            Callback for early stopping.
+            If set to ``True``, then a default callback monitoring ``'val_loss'`` is created.
             Will raise an error if ``'val_loss'`` is not found.
             If set to ``False``, then early stopping will be disabled.
             If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
@@ -626,6 +629,7 @@ def on_train_end(self):
 
         # Transfer params
         # Backward compatibility
+        self.num_nodes = num_nodes
         if nb_gpu_nodes is not None:
             warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                           " and this method will be removed in v0.8.0", DeprecationWarning)
@@ -746,10 +750,12 @@ def on_train_end(self):
         self.weights_save_path = weights_save_path
 
         # accumulated grads
+        self.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
         # allow int, string and gpu list
-        self.data_parallel_device_ids = parse_gpu_ids(gpus)
+        self.gpus = gpus
+        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
 
         # tpu state flags
@@ -796,13 +802,17 @@ def on_train_end(self):
         self.row_log_interval = row_log_interval
 
         # how much of the data to use
+        self.overfit_pct = overfit_pct
         self.determine_data_use_amount(train_percent_check, val_percent_check,
                                        test_percent_check, overfit_pct)
 
         # 16 bit mixed precision training using apex
        self.amp_level = amp_level
         self.precision = precision
-        if self.precision == 16:
+
+        assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported'
+
+        if self.precision == 16 and num_tpu_cores is None:
             use_amp = True
         self.init_amp(use_amp)
@@ -818,6 +828,28 @@ def slurm_job_id(self) -> int:
             job_id = None
         return job_id
 
+    @classmethod
+    def default_attributes(cls):
+        return vars(cls())
+
+    @classmethod
+    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
+        """Extend existing argparse by default `Trainer` attributes."""
+        parser = ArgumentParser(parents=[parent_parser])
+
+        trainer_default_params = Trainer.default_attributes()
+
+        for arg in trainer_default_params:
+            parser.add_argument('--{0}'.format(arg), default=trainer_default_params[arg], dest=arg)
+
+        return parser
+
+    @classmethod
+    def from_argparse_args(cls, args):
+
+        params = vars(args)
+        return cls(**params)
+
     def __parse_gpu_ids(self, gpus):
         """Parse GPUs id.
@@ -937,8 +969,8 @@ def fit(
             # feed to .fit()
 
         """
-        # Fit begin callbacks
-        self.on_fit_start()
+        # bind logger
+        model.logger = self.logger
 
         # set up the passed in dataloaders (if needed)
         self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)
@@ -954,8 +986,18 @@ def fit(
                 task = int(os.environ['SLURM_LOCALID'])
                 self.ddp_train(task, model)
             else:
+                self.__set_random_port()
+
+                # track for predict
+                self.model = model
+
+                # train
                 mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
 
+                # load weights if not interrupted
+                self.load_spawn_weights(model)
+                self.model = model
+
         # 1 gpu or dp option triggers training using DP module
         # easier to avoid NCCL issues
         elif self.use_dp:
@@ -969,8 +1011,17 @@ def fit(
             # COLAB_GPU is an env var available by default in Colab environments.
             start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
+
+            # track for predict
+            self.model = model
+
+            # train
             xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)
 
+            # load weights if not interrupted
+            self.load_spawn_weights(model)
+            self.model = model
+
         # ON CPU
         else:
             # run through amp wrapper
@@ -983,13 +1034,22 @@ def fit(
 
             self.run_pretrain_routine(model)
 
-        # Fit end callbacks
-        self.on_fit_end()
-
         # return 1 when finished
         # used for testing or when we need to know that training succeeded
         return 1
 
+    def __set_random_port(self):
+        """
+        When running DDP NOT managed by SLURM, the ports might collide
+        :return:
+        """
+        try:
+            default_port = os.environ['MASTER_PORT']
+        except Exception:
+            import random
+            default_port = random.randint(10000, 19000)
+            os.environ['MASTER_PORT'] = str(default_port)
+
     def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_dataloaders):
         # when dataloader is passed via fit, patch the train_dataloader
         # functions to overwrite with these implementations
@@ -998,30 +1058,21 @@ def __set_fit_dataloaders(self, model, train_dataloader, val_dataloaders, test_d
                 m = 'You called .fit() with a train_dataloader but did not define training_step()'
                 raise MisconfigurationException(m)
 
-            def patch_train_dataloader():
-                return train_dataloader
-
-            model.train_dataloader = patch_train_dataloader
+            model.train_dataloader = _PatchDataLoader(train_dataloader)
 
         if val_dataloaders is not None:
             if not self.is_overriden('validation_step', model):
                 m = 'You called .fit() with a val_dataloaders but did not define validation_step()'
                 raise MisconfigurationException(m)
 
-            def patch_val_dataloader():
-                return val_dataloaders
-
-            model.val_dataloader = patch_val_dataloader
+            model.val_dataloader = _PatchDataLoader(val_dataloaders)
 
         if test_dataloaders is not None:
             if not self.is_overriden('test_step', model):
                 m = 'You called .fit() with a test_dataloaders but did not define test_step()'
                 raise MisconfigurationException(m)
 
-            def patch_test_dataloader():
-                return test_dataloaders
-
-            model.test_dataloader = patch_test_dataloader
+            model.test_dataloader = _PatchDataLoader(test_dataloaders)
 
     def init_optimizers(
             self,
@@ -1065,10 +1116,8 @@ def run_pretrain_routine(self, model: LightningModule):
         # set local properties on the model
         self.copy_trainer_model_properties(ref_model)
 
-        # link up experiment object
+        # log hyper-parameters
         if self.logger is not None:
-            ref_model.logger = self.logger
-
             # save exp to get started
             if hasattr(ref_model, "hparams"):
                 self.logger.log_hyperparams(ref_model.hparams)
@@ -1090,7 +1139,8 @@ def run_pretrain_routine(self, model: LightningModule):
         self.register_slurm_signal_handlers()
 
         # print model summary
-        if self.proc_rank == 0 and self.weights_summary is not None:
+        # TODO: remove self.testing condition because model.summarize() is wiping out the weights
+        if self.proc_rank == 0 and self.weights_summary is not None and not self.testing:
             if self.weights_summary in ['full', 'top']:
                 ref_model.summarize(mode=self.weights_summary)
             else:
@@ -1110,23 +1160,18 @@ def run_pretrain_routine(self, model: LightningModule):
         # when testing requested only run test and return
         if self.testing:
             # only load test dataloader for testing
-            self.reset_test_dataloader(ref_model)
+            # self.reset_test_dataloader(ref_model)
             self.run_evaluation(test_mode=True)
             return
 
-        # load the dataloaders
-        self.reset_train_dataloader(ref_model)
-        self.reset_val_dataloader(ref_model)
-
         # check if we should run validation during training
-        self.disable_validation = self.num_val_batches == 0 or not self.is_overriden('validation_step')
-        self.disable_validation = self.disable_validation and not self.fast_dev_run
+        self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run
 
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         ref_model.on_sanity_check_start()
-        ref_model.on_train_start()
         if not self.disable_validation and self.num_sanity_val_steps > 0:
+            self.reset_val_dataloader(ref_model)
             # init progress bars for validation sanity check
             pbar = tqdm(desc='Validation sanity check',
                         total=self.num_sanity_val_steps * len(self.val_dataloaders),
@@ -1168,7 +1213,7 @@ def test(self, model: Optional[LightningModule] = None):
         Separates from fit to make sure you never run on your test set until you want to.
 
         Args:
-            model: The model to test.
+            model (:class:`.LightningModule`): The model to test.
 
         Example::
@@ -1186,10 +1231,36 @@ def test(self, model: Optional[LightningModule] = None):
             trainer = Trainer()
             trainer.test(model)
         """
+        self.testing = True
         if model is not None:
+            self.model = model
             self.fit(model)
-        self.run_evaluation(test_mode=True)
+        elif self.use_ddp or self.use_tpu:
+            # attempt to load weights from a spawn
+            path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
+            test_model = self.model
+            if os.path.exists(path):
+                test_model = self.load_spawn_weights(self.model)
+
+            self.fit(test_model)
+        else:
+            self.run_evaluation(test_mode=True)
+
+
+class _PatchDataLoader(object):
+    r'''
+    Callable object for patching dataloaders passed into trainer.fit().
+    Use this class to override model.*_dataloader() and be pickle-compatible.
+
+    Args:
+        dataloader: Dataloader object to return when called.
+    '''
+    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
+        self.dataloader = dataloader
+
+    def __call__(self) -> Union[List[DataLoader], DataLoader]:
+        return self.dataloader
 
 
 def _set_dataloader(model, dataloader, attribute):
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 7fa4059afc3e2..908bddf0191d0 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -30,8 +30,6 @@ def print_nan_gradients(self):
                 log.info(param, param.grad)
 
     def configure_accumulated_gradients(self, accumulate_grad_batches):
-        self.accumulate_grad_batches = None
-
         if isinstance(accumulate_grad_batches, dict):
             self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index 7850638475ad7..d95fce3e5c3bb 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -1,10 +1,11 @@
 import math
 import os
-
 import pytest
 import torch
+import argparse
 
 import tests.models.utils as tutils
+from unittest import mock
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import (
     EarlyStopping,
@@ -16,11 +17,6 @@
     LightEmptyTestStep,
     LightValidationStepMixin,
     LightValidationMultipleDataloadersMixin,
-    LightTestMultipleDataloadersMixin,
-    LightTestFitSingleTestDataloadersMixin,
-    LightTestFitMultipleTestDataloadersMixin,
-    LightValStepFitMultipleDataloadersMixin,
-    LightValStepFitSingleDataloaderMixin,
     LightTrainDataloader,
     LightTestDataloader,
     LightValidationMixin,
@@ -66,8 +62,10 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase):
 
     # load new model
     tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
-    model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
-                                                   tags_csv=tags_path)
+    model_2 = LightningTestModel.load_from_checkpoint(
+        checkpoint_path=new_weights_path,
+        tags_csv=tags_path
+    )
     model_2.eval()
 
 
@@ -104,8 +102,10 @@ class CurrentTestModel(LightTrainDataloader, LightValidationStepMixin, TestModel
 
     # load new model
     tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
-    model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
-                                                   tags_csv=tags_path)
+    model_2 = LightningTestModel.load_from_checkpoint(
+        checkpoint_path=new_weights_path,
+        tags_csv=tags_path
+    )
     model_2.eval()
 
 
@@ -258,7 +258,7 @@ def mock_save_function(filepath):
 
     # verify correct naming
     for i in range(0, len(losses)):
-        assert f'_ckpt_epoch_{i}.ckpt' in file_lists
+        assert f"_ckpt_epoch_{i}.ckpt" in file_lists
 
     save_dir = tmp_path / "2"
     save_dir.mkdir()
@@ -307,7 +307,7 @@ def mock_save_function(filepath):
 
     # make sure other files don't get deleted
     checkpoint_callback = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)
-    open(f'{save_dir}/other_file.ckpt', 'a').close()
+    open(f"{save_dir}/other_file.ckpt", 'a').close()
     checkpoint_callback.save_function = mock_save_function
     trainer = Trainer()
@@ -380,44 +380,6 @@ def test_model_freeze_unfreeze():
     model.unfreeze()
 
 
-def test_multiple_val_dataloader(tmpdir):
-    """Verify multiple val_dataloader."""
-    tutils.reset_seed()
-
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValidationMultipleDataloadersMixin,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_hparams()
-    model = CurrentTestModel(hparams)
-
-    # logger file to get meta
-    trainer_options = dict(
-        default_save_path=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=1.0,
-    )
-
-    # fit model
-    trainer = Trainer(**trainer_options)
-    result = trainer.fit(model)
-
-    # verify training completed
-    assert result == 1
-
-    # verify there are 2 val loaders
-    assert len(trainer.val_dataloaders) == 2, \
-        'Multiple val_dataloaders not initiated properly'
-
-    # make sure predictions are good for each val set
-    for dataloader in trainer.val_dataloaders:
-        tutils.run_prediction(dataloader, trainer.model)
-
-
 def test_resume_from_checkpoint_epoch_restored(tmpdir):
     """Verify resuming from checkpoint runs the right number of epochs"""
     import types
@@ -486,221 +448,6 @@ def increment_batch(self, _):
 
     assert state['global_step'] + next_model.num_batches_seen == training_batches * 4
 
 
-def test_multiple_test_dataloader(tmpdir):
-    """Verify multiple test_dataloader."""
-    tutils.reset_seed()
-
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightTestMultipleDataloadersMixin,
-        LightEmptyTestStep,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_hparams()
-    model = CurrentTestModel(hparams)
-
-    # logger file to get meta
-    trainer_options = dict(
-        default_save_path=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=0.2
-    )
-
-    # fit model
-    trainer = Trainer(**trainer_options)
-    trainer.fit(model)
-    trainer.test()
-
-    # verify there are 2 val loaders
-    assert len(trainer.test_dataloaders) == 2, \
-        'Multiple test_dataloaders not initiated properly'
-
-    # make sure predictions are good for each test set
-    for dataloader in trainer.test_dataloaders:
-        tutils.run_prediction(dataloader, trainer.model)
-
-    # run the test method
-    trainer.test()
-
-
-def test_train_dataloaders_passed_to_fit(tmpdir):
-    """ Verify that train dataloader can be passed to fit """
-    tutils.reset_seed()
-
-    class CurrentTestModel(LightTrainDataloader, TestModelBase):
-        pass
-
-    hparams = tutils.get_hparams()
-
-    # logger file to get meta
-    trainer_options = dict(
-        default_save_path=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=0.2
-    )
-
-    # only train passed to fit
-    model = CurrentTestModel(hparams)
-    trainer = Trainer(**trainer_options)
-    fit_options = dict(train_dataloader=model._dataloader(train=True))
-    results = trainer.fit(model, **fit_options)
-
-
-def test_train_val_dataloaders_passed_to_fit(tmpdir):
-    """ Verify that train & val dataloader can be passed to fit """
-    tutils.reset_seed()
-
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValStepFitSingleDataloaderMixin,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_hparams()
-
-    # logger file to get meta
-    trainer_options = dict(
-        default_save_path=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=0.2
-    )
-
-    # train, val passed to fit
-    model = CurrentTestModel(hparams)
-    trainer = Trainer(**trainer_options)
-    fit_options = dict(train_dataloader=model._dataloader(train=True),
-                       val_dataloaders=model._dataloader(train=False))
-
-    results = trainer.fit(model, **fit_options)
-    assert len(trainer.val_dataloaders) == 1, \
-        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
-
-
-def test_all_dataloaders_passed_to_fit(tmpdir):
-    """ Verify train, val & test dataloader can be passed to fit """
-    tutils.reset_seed()
-
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValStepFitSingleDataloaderMixin,
-        LightTestFitSingleTestDataloadersMixin,
-        LightEmptyTestStep,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_hparams()
-
-    # logger file to get meta
-    trainer_options = dict(
-        default_save_path=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=0.2
-    )
-
-    # train, val and test passed to fit
-    model = CurrentTestModel(hparams)
-    trainer = Trainer(**trainer_options)
-    fit_options = dict(train_dataloader=model._dataloader(train=True),
-                       val_dataloaders=model._dataloader(train=False),
-                       test_dataloaders=model._dataloader(train=False))
-
-    results = trainer.fit(model, **fit_options)
-
-    trainer.test()
-
-    assert len(trainer.val_dataloaders) == 1, \
-        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
-    assert len(trainer.test_dataloaders) == 1, \
-        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
-
-
-def test_multiple_dataloaders_passed_to_fit(tmpdir):
-    """ Verify that multiple val & test dataloaders can be passed to fit """
-    tutils.reset_seed()
-
-    class CurrentTestModel(
-        LightningTestModel,
-        LightValStepFitMultipleDataloadersMixin,
-        LightTestFitMultipleTestDataloadersMixin,
-    ):
-        pass
-
-    hparams = tutils.get_hparams()
-
-    # logger file to get meta
-    trainer_options = dict(
-        default_save_path=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=0.2
-    )
-
-    # train, multiple val and multiple test passed to fit
-    model = CurrentTestModel(hparams)
-    trainer = Trainer(**trainer_options)
-    fit_options = dict(train_dataloader=model._dataloader(train=True),
-                       val_dataloaders=[model._dataloader(train=False),
-                                        model._dataloader(train=False)],
-                       test_dataloaders=[model._dataloader(train=False),
-                                         model._dataloader(train=False)])
-    results = trainer.fit(model, **fit_options)
-    trainer.test()
-
-    assert len(trainer.val_dataloaders) == 2, \
-        f'Multiple `val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
-    assert len(trainer.test_dataloaders) == 2, \
-        f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
-
-
-def test_mixing_of_dataloader_options(tmpdir):
-    """Verify that dataloaders can be passed to fit"""
-    tutils.reset_seed()
-
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightValStepFitSingleDataloaderMixin,
-        LightTestFitSingleTestDataloadersMixin,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_hparams()
-    model = CurrentTestModel(hparams)
-
-    # logger file to get meta
-    trainer_options = dict(
-        default_save_path=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=0.2
-    )
-
-    # fit model
-    trainer = Trainer(**trainer_options)
-    fit_options = dict(val_dataloaders=model._dataloader(train=False))
-    results = trainer.fit(model, **fit_options)
-
-    # fit model
-    trainer = Trainer(**trainer_options)
-    fit_options = dict(val_dataloaders=model._dataloader(train=False),
-                       test_dataloaders=model._dataloader(train=False))
-    _ = trainer.fit(model, **fit_options)
-    trainer.test()
-
-    assert len(trainer.val_dataloaders) == 1, \
-        f'`val_dataloaders` not initiated properly, got {trainer.val_dataloaders}'
-    assert len(trainer.test_dataloaders) == 1, \
-        f'`test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
-
-
 def _init_steps_model():
     """private method for initializing a model with 5% train epochs"""
     tutils.reset_seed()
@@ -855,121 +602,21 @@ def test_end(self, outputs):
     model = LightningTestModel(hparams)
     Trainer().test(model)
 
+@mock.patch('argparse.ArgumentParser.parse_args',
+            return_value=argparse.Namespace(**Trainer.default_attributes()))
+def test_default_args(mock_argparse, tmpdir):
+    """Tests default argument parser for Trainer"""
+    tutils.reset_seed()
 
-def test_trainer_callback_system(tmpdir):
-    """Test the callback system."""
-
-    class CurrentTestModel(
-        LightTrainDataloader,
-        LightTestMixin,
-        LightValidationMixin,
-        TestModelBase,
-    ):
-        pass
-
-    hparams = tutils.get_hparams()
-    model = CurrentTestModel(hparams)
-
-    class TestCallback(Callback):
-        def __init__(self):
-            super().__init__()
-            self.on_init_start_called = False
-            self.on_init_end_called = False
-            self.on_fit_start_called = False
-            self.on_fit_end_called = False
-            self.on_epoch_start_called = False
-            self.on_epoch_end_called = False
-            self.on_batch_start_called = False
-            self.on_batch_end_called = False
-            self.on_train_start_called = False
-            self.on_train_end_called = False
-            self.on_validation_start_called = False
-            self.on_validation_end_called = False
-            self.on_test_start_called = False
-            self.on_test_end_called = False
-
-        def on_init_start(self, trainer, pl_module):
-            self.on_init_start_called = True
-
-        def on_init_end(self, trainer, pl_module):
-            self.on_init_end_called = True
-
-        def on_fit_start(self, trainer, pl_module):
-            self.on_fit_start_called = True
-
-        def on_fit_end(self, trainer, pl_module):
-            self.on_fit_end_called = True
-
-        def on_epoch_start(self, trainer, pl_module):
-            self.on_epoch_start_called = True
-
-        def on_epoch_end(self, trainer, pl_module):
-            self.on_epoch_end_called = True
-
-        def on_batch_start(self, trainer, pl_module):
-            self.on_batch_start_called = True
-
-        def on_batch_end(self, trainer, pl_module):
-            self.on_batch_end_called = True
-
-        def on_train_start(self, trainer, pl_module):
-            self.on_train_start_called = True
-
-        def on_train_end(self, trainer, pl_module):
-            self.on_train_end_called = True
-
-        def on_validation_start(self, trainer, pl_module):
-            self.on_validation_start_called = True
-
-        def on_validation_end(self, trainer, pl_module):
-            self.on_validation_end_called = True
-
-        def on_test_start(self, trainer, pl_module):
-            self.on_test_start_called = True
-
-        def on_test_end(self, trainer, pl_module):
-            self.on_test_end_called = True
-
-    test_callback = TestCallback()
-
-    trainer_options = {}
-    trainer_options['callbacks'] = [test_callback]
-    trainer_options['max_epochs'] = 1
-    trainer_options['val_percent_check'] = 0.1
-    trainer_options['train_percent_check'] = 0.2
-    trainer_options['show_progress_bar'] = False
-
-    assert not test_callback.on_init_start_called
-    assert not test_callback.on_init_end_called
-
-    # fit model
-    trainer = Trainer(**trainer_options)
+    # logger file to get meta
+    logger = tutils.get_test_tube_logger(tmpdir, False)
 
-    assert trainer.callbacks[0] == test_callback
-    assert test_callback.on_init_start_called
-    assert test_callback.on_init_end_called
-    assert not test_callback.on_fit_start_called
-    assert not test_callback.on_fit_start_called
+    parser = argparse.ArgumentParser(add_help=False)
+    args = parser.parse_args()
+    args.logger = logger
 
-    trainer.fit(model)
+    args.max_epochs = 5
+    trainer = Trainer.from_argparse_args(args)
 
-    assert test_callback.on_fit_start_called
-    assert test_callback.on_fit_end_called
-    assert test_callback.on_epoch_start_called
-    assert test_callback.on_epoch_start_called
-    assert test_callback.on_batch_start_called
-    assert test_callback.on_batch_end_called
-    assert test_callback.on_train_start_called
-    assert test_callback.on_train_end_called
-    assert test_callback.on_validation_start_called
-    assert test_callback.on_validation_end_called
-    assert not test_callback.on_test_start_called
-    assert not test_callback.on_test_end_called
-
-    trainer.test()
-
-    assert test_callback.on_test_start_called
-    assert test_callback.on_test_end_called
-
-# if __name__ == '__main__':
-#     pytest.main([__file__])
+    assert isinstance(trainer, Trainer)
+    assert trainer.max_epochs == 5
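
Usage note (editorial, not part of the patch): the sketch below shows how the new Trainer.default_attributes / Trainer.add_argparse_args / Trainer.from_argparse_args helpers introduced in this diff could be wired into a training script. The --my_learning_rate flag and the main() wrapper are illustrative assumptions only; the helpers and the **kwargs catch-all on Trainer.__init__ are what the patch itself provides.

from argparse import ArgumentParser

from pytorch_lightning import Trainer


def main():
    # the parent parser must disable -h/--help so add_argparse_args can wrap it
    # via ArgumentParser(parents=[parent_parser]) without an option clash
    parser = ArgumentParser(add_help=False)

    # hypothetical model-specific flag, shown only for illustration
    parser.add_argument('--my_learning_rate', type=float, default=0.02)

    # append one --flag per default Trainer attribute
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # build the Trainer straight from the parsed Namespace; keys that are not
    # Trainer.__init__ parameters (e.g. my_learning_rate) are absorbed by **kwargs
    trainer = Trainer.from_argparse_args(args)
    return trainer


if __name__ == '__main__':
    main()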
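
Design note (editorial, not part of the patch): _PatchDataLoader replaces the locally defined patch_*_dataloader closures because a local function cannot be pickled, which is what breaks the ddp/tpu spawn paths, while an instance of a module-level callable class can be. A small self-contained sketch of that difference; the names below are illustrative and mirror, rather than import, the class added above.

import pickle


class PatchDataLoader(object):
    """Module-level callable: a picklable stand-in for _PatchDataLoader."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __call__(self):
        return self.dataloader


def make_closure(dataloader):
    # locally defined function, as the old __set_fit_dataloaders code created
    def patch_train_dataloader():
        return dataloader
    return patch_train_dataloader


if __name__ == '__main__':
    payload = [1, 2, 3]  # stand-in for a real DataLoader

    # works: the class can be looked up by module and name when unpickling
    pickle.dumps(PatchDataLoader(payload))

    # fails: pickle cannot locate a function defined inside another function
    try:
        pickle.dumps(make_closure(payload))
    except (pickle.PicklingError, AttributeError) as err:
        print('closure is not picklable:', err)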