From 8920d69440ae24c6aee946b9a771a8cfec0dd5f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 11 Aug 2020 16:13:56 +0200 Subject: [PATCH] rename --- tests/callbacks/test_progress_bar.py | 6 +++--- tests/utilities/test_data.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 96745e08ad8b17..347716a8eeb0ef 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -195,7 +195,7 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx assert progress_bar.test_batches_seen == progress_bar.total_test_batches -@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches,num_sanity_check_run_steps', [ +@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches,expected_num_steps', [ (-1, [10], 10), (0, [10], 0), (2, [10], 2), @@ -206,7 +206,7 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx (10, [float('inf')], 10), (10, [1, float('inf')], 11), ]) -def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, num_sanity_check_run_steps): +def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, expected_num_steps): """Test that the sanity_check progress finishes with the correct total steps processed.""" tmp_model = EvalModelTemplate(batch_size=1) @@ -232,4 +232,4 @@ def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_d trainer.fit(model, val_dataloaders=val_dataloaders) val_progress_bar = trainer.progress_bar_callback.val_progress_bar - assert getattr(val_progress_bar, 'total', 0) == num_sanity_check_run_steps + assert getattr(val_progress_bar, 'total', 0) == expected_num_steps diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 31b5ee5ceb5668..b0c31b68a561be 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -14,7 +14,7 @@ reason="IterableDataset with __len__ before 1.4 raises", ) def test_warning_with_iterable_dataset_and_len(tmpdir): - """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. """ + """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """ model = EvalModelTemplate() original_dataset = model.train_dataloader().dataset