prune check on Trainer fit result (#5453)
* prune check on Trainer fit result

* flake8

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>

* .

Co-authored-by: Carlos Mocholí <[email protected]>
Borda and carmocca authored Jan 12, 2021
1 parent b9530d2 commit 059f463
Showing 25 changed files with 212 additions and 201 deletions.
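Every file below follows the same pattern: the tests no longer assert on the value returned by trainer.fit() and instead check trainer.state after fitting. A minimal sketch of the new idiom, assuming the BoringModel helper from the Lightning test suite; the test name and the fast_dev_run/default_root_dir settings are illustrative, not part of this commit:

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState
from tests.base import BoringModel


def test_fit_sets_finished_state(tmpdir):
    # Before this commit: result = trainer.fit(model); assert result == 1
    # After this commit: run fit, then check the trainer's terminal state.
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"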
5 changes: 3 additions & 2 deletions tests/backends/test_ddp_spawn.py
@@ -17,6 +17,7 @@
 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.trainer.states import TrainerState
 from tests.base import EvalModelTemplate
 from pytorch_lightning.core import memory
 from pytorch_lightning.trainer import Trainer
@@ -81,5 +82,5 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
         gpus=[0, 1],
         accelerator='ddp_spawn',
     )
-    result = trainer.fit(model, **fit_options)
-    assert result == 1, "DDP doesn't work with dataloaders passed to fit()."
+    trainer.fit(model, **fit_options)
+    assert trainer.state == TrainerState.FINISHED, "DDP doesn't work with dataloaders passed to fit()."
6 changes: 3 additions & 3 deletions tests/backends/test_tpu_backend.py
@@ -16,6 +16,7 @@
 import torch

 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.xla_device import XLADeviceUtils
 from tests.base.boring_model import BoringModel
 from tests.base.develop_utils import pl_multi_process_test
@@ -45,9 +46,8 @@ def test_resume_training_on_cpu(tmpdir):
         max_epochs=1,
         default_root_dir=tmpdir,
     )
-    result = trainer.fit(model)
-
-    assert result == 1
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


 @pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
9 changes: 5 additions & 4 deletions tests/base/develop_pipelines.py
@@ -14,6 +14,7 @@
 import torch

 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.states import TrainerState
 from tests.base import BoringModel
 from tests.base.develop_utils import get_default_logger, load_model_from_checkpoint, reset_seed

@@ -23,10 +24,10 @@ def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50

     # fit model
     trainer = Trainer(**trainer_options)
-    result = trainer.fit(model)
+    trainer.fit(model)

     # correct result and ok accuracy
-    assert result == 1, 'amp + ddp model failed to complete'
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     pretrained_model = load_model_from_checkpoint(
         trainer.logger,
@@ -60,10 +61,10 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,

     trainer = Trainer(**trainer_options)
     initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
-    result = trainer.fit(model)
+    trainer.fit(model)
     post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])

-    assert result == 1, 'trainer failed'
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     # Check that the model is actually changed post-training
     change_ratio = torch.norm(initial_values - post_train_values)
     assert change_ratio > 0.1, f"the model is changed of {change_ratio}"
5 changes: 3 additions & 2 deletions tests/callbacks/test_early_stopping.py
@@ -22,6 +22,7 @@

 from pytorch_lightning import _logger, seed_everything, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import BoringModel, EvalModelTemplate

@@ -162,9 +163,9 @@ def training_step(self, *args, **kwargs):
         overfit_batches=0.20,
         max_epochs=10,
     )
-    result = trainer.fit(model)
+    trainer.fit(model)

-    assert result == 1, 'training failed to complete'
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     assert trainer.current_epoch < trainer.max_epochs - 1


5 changes: 3 additions & 2 deletions tests/callbacks/test_gpu_stats_monitor.py
@@ -20,6 +20,7 @@
 from pytorch_lightning.callbacks import GPUStatsMonitor
 from pytorch_lightning.loggers import CSVLogger
 from pytorch_lightning.loggers.csv_logs import ExperimentWriter
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate

@@ -44,8 +45,8 @@ def test_gpu_stats_monitor(tmpdir):
         logger=logger
     )

-    results = trainer.fit(model)
-    assert results
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE)
     met_data = np.genfromtxt(path_csv, delimiter=',', names=True, deletechars='', replace_space=' ')
25 changes: 13 additions & 12 deletions tests/callbacks/test_lr_monitor.py
@@ -17,6 +17,7 @@
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import BoringModel, EvalModelTemplate

@@ -36,8 +37,8 @@ def test_lr_monitor_single_lr(tmpdir):
         limit_train_batches=0.5,
         callbacks=[lr_monitor],
     )
-    result = trainer.fit(model)
-    assert result
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     assert lr_monitor.lrs, 'No learning rates logged'
     assert all(v is None for v in lr_monitor.last_momentum_values.values()), \
@@ -78,8 +79,8 @@ def configure_optimizers(self):
         log_every_n_steps=1,
         callbacks=[lr_monitor],
     )
-    result = trainer.fit(model)
-    assert result
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     assert all(v is not None for v in lr_monitor.last_momentum_values.values()), \
         'Expected momentum to be logged'
@@ -110,8 +111,8 @@ def configure_optimizers(self):
         callbacks=[lr_monitor],
     )
     with pytest.warns(RuntimeWarning, match="optimizers do not have momentum."):
-        result = trainer.fit(model)
-        assert result
+        trainer.fit(model)
+        assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     assert all(v == 0 for v in lr_monitor.last_momentum_values.values()), \
         'Expected momentum to be logged'
@@ -136,8 +137,8 @@ def test_lr_monitor_no_lr_scheduler(tmpdir):
     )

     with pytest.warns(RuntimeWarning, match='have no learning rate schedulers'):
-        result = trainer.fit(model)
-        assert result
+        trainer.fit(model)
+        assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


 def test_lr_monitor_no_logger(tmpdir):
@@ -176,8 +177,8 @@ def test_lr_monitor_multi_lrs(tmpdir, logging_interval):
         limit_val_batches=0.1,
         callbacks=[lr_monitor],
     )
-    result = trainer.fit(model)
-    assert result
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     assert lr_monitor.lrs, 'No learning rates logged'
     assert len(lr_monitor.lrs) == len(trainer.lr_schedulers), \
@@ -209,8 +210,8 @@ def test_lr_monitor_param_groups(tmpdir):
         limit_train_batches=0.5,
         callbacks=[lr_monitor],
     )
-    result = trainer.fit(model)
-    assert result
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     assert lr_monitor.lrs, 'No learning rates logged'
     assert len(lr_monitor.lrs) == 2 * len(trainer.lr_schedulers), \
5 changes: 3 additions & 2 deletions tests/checkpointing/test_model_checkpoint.py
@@ -31,6 +31,7 @@
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import BoringModel
@@ -190,8 +191,8 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
         callbacks=[model_checkpoint],
         max_epochs=num_epochs,
     )
-    result = trainer.fit(model)
-    assert 1 == result
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


 def test_model_checkpoint_format_checkpoint_name(tmpdir):
29 changes: 15 additions & 14 deletions tests/core/test_datamodules.py
@@ -21,6 +21,7 @@
 from torch.utils.data import DataLoader, random_split

 from pytorch_lightning import LightningDataModule, Trainer
+from pytorch_lightning.trainer.states import TrainerState
 from tests.base import EvalModelTemplate
 from tests.base.datasets import TrialMNIST
 from tests.base.datamodules import TrialMNISTDataModule
@@ -206,8 +207,8 @@ def test_train_loop_only(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, dm)
-    assert result == 1
+    trainer.fit(model, dm)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     assert trainer.logger_connector.callback_metrics['loss'] < 0.6


@@ -228,8 +229,8 @@ def test_train_val_loop_only(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, dm)
-    assert result == 1
+    trainer.fit(model, dm)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     assert trainer.logger_connector.callback_metrics['loss'] < 0.6


@@ -247,8 +248,8 @@ def test_dm_checkpoint_save(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, dm)
-    assert result
+    trainer.fit(model, dm)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
     checkpoint = torch.load(checkpoint_path)
     assert dm.__class__.__name__ in checkpoint
@@ -285,8 +286,8 @@ def test_full_loop(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, dm)
-    assert result == 1
+    trainer.fit(model, dm)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     # test
     result = trainer.test(datamodule=dm)
@@ -309,8 +310,8 @@ def test_trainer_attached_to_dm(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, dm)
-    assert result == 1
+    trainer.fit(model, dm)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     assert dm.trainer is not None

     # test
@@ -336,8 +337,8 @@ def test_full_loop_single_gpu(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, dm)
-    assert result == 1
+    trainer.fit(model, dm)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     # test
     result = trainer.test(datamodule=dm)
@@ -363,8 +364,8 @@ def test_full_loop_dp(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, dm)
-    assert result == 1
+    trainer.fit(model, dm)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     # test
     result = trainer.test(datamodule=dm)
5 changes: 3 additions & 2 deletions tests/core/test_results.py
@@ -21,6 +21,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.step_result import Result, EvalResult
 import tests.base.develop_utils as tutils
+from pytorch_lightning.trainer.states import TrainerState

 from tests.base import EvalModelTemplate
 from tests.base.datamodules import TrialMNISTDataModule
@@ -121,8 +122,8 @@ def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
     assert not prediction_file.exists()

     if do_train:
-        result = trainer.fit(model, dm)
-        assert result == 1
+        trainer.fit(model, dm)
+        assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
         result = trainer.test(datamodule=dm)
         result = result[0]
         assert result['test_loss'] < 0.6
5 changes: 3 additions & 2 deletions tests/loggers/test_all.py
@@ -32,6 +32,7 @@
     WandbLogger,
 )
 from pytorch_lightning.loggers.base import DummyExperiment
+from pytorch_lightning.trainer.states import TrainerState
 from tests.base import BoringModel
 from tests.loggers.test_comet import _patch_comet_atexit
 from tests.loggers.test_mlflow import mock_mlflow_run_creation
@@ -343,8 +344,8 @@ def _test_logger_created_on_rank_zero_only(tmpdir, logger_class):
         checkpoint_callback=True,
         callbacks=[RankZeroLoggerCheck()],
     )
-    result = trainer.fit(model)
-    assert result == 1
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


 def test_logger_with_prefix_all(tmpdir, monkeypatch):
9 changes: 5 additions & 4 deletions tests/loggers/test_base.py
@@ -21,6 +21,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import rank_zero_only
 from tests.base import BoringModel

@@ -108,8 +109,8 @@ def training_step(self, batch, batch_idx):
         logger=logger,
         default_root_dir=tmpdir,
     )
-    result = trainer.fit(model)
-    assert result, "Training failed"
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     assert logger.hparams_logged == model.hparams
     assert logger.metrics_logged != {}
     assert logger.finalized_status == "success"
@@ -133,8 +134,8 @@ def training_step(self, batch, batch_idx):
         logger=[logger1, logger2],
         default_root_dir=tmpdir,
     )
-    result = trainer.fit(model)
-    assert result == 1, "Training failed"
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     assert logger1.hparams_logged == model.hparams
     assert logger1.metrics_logged != {}
6 changes: 4 additions & 2 deletions tests/models/data/horovod/train_default_model.py
@@ -22,6 +22,8 @@
 import sys

 # this is need as e.g. Conda do not uses `PYTHONPATH` env var as pip or/and virtualenv
+from pytorch_lightning.trainer.states import TrainerState
+
 sys.path = os.getenv('PYTHONPATH').split(':') + sys.path

 from pytorch_lightning import Trainer  # noqa: E402
@@ -54,8 +56,8 @@ def run_test_from_config(trainer_options):
     model = EvalModelTemplate()

     trainer = Trainer(**trainer_options)
-    result = trainer.fit(model)
-    assert result == 1
+    trainer.fit(model)
+    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     # Horovod should be initialized following training. If not, this will raise an exception.
     assert hvd.size() == 2
[Diffs for the remaining changed files are not expanded in this view.]
