Skip to content

Commit

Permalink
rename
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored and ananyahjha93 committed Aug 11, 2020
1 parent 562a1e0 commit 8920d69
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions tests/callbacks/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx
assert progress_bar.test_batches_seen == progress_bar.total_test_batches


@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches,num_sanity_check_run_steps', [
@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches,expected_num_steps', [
(-1, [10], 10),
(0, [10], 0),
(2, [10], 2),
Expand All @@ -206,7 +206,7 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx
(10, [float('inf')], 10),
(10, [1, float('inf')], 11),
])
def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, num_sanity_check_run_steps):
def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, expected_num_steps):
"""Test that the sanity_check progress finishes with the correct total steps processed."""

tmp_model = EvalModelTemplate(batch_size=1)
Expand All @@ -232,4 +232,4 @@ def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_d
trainer.fit(model, val_dataloaders=val_dataloaders)

val_progress_bar = trainer.progress_bar_callback.val_progress_bar
assert getattr(val_progress_bar, 'total', 0) == num_sanity_check_run_steps
assert getattr(val_progress_bar, 'total', 0) == expected_num_steps
2 changes: 1 addition & 1 deletion tests/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
reason="IterableDataset with __len__ before 1.4 raises",
)
def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning messages is shown when an IterableDataset defines `__len__`. """
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
model = EvalModelTemplate()
original_dataset = model.train_dataloader().dataset

Expand Down

0 comments on commit 8920d69

Please sign in to comment.