[WIP] Fix the progress bar for the sanity check #2892
Changes from all commits: 0e591aa, 0a9342a, f87074f, 8b1fde4, a6be809, 8041aa9, 117c1fe, 5987308, e978a58, 03191e7
```diff
@@ -17,6 +17,7 @@
 from tqdm import tqdm

 from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.data import has_len


 class ProgressBarBase(Callback):
@@ -293,7 +294,9 @@ def init_test_tqdm(self) -> tqdm:
     def on_sanity_check_start(self, trainer, pl_module):
         super().on_sanity_check_start(trainer, pl_module)
         self.val_progress_bar = self.init_sanity_tqdm()
-        self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
+        self.val_progress_bar.total = sum(
+            min(trainer.num_sanity_val_steps, len(d) if has_len(d) else float('inf')) for d in trainer.val_dataloaders
+        )
         self.main_progress_bar = tqdm(disable=True)  # dummy progress bar

     def on_sanity_check_end(self, trainer, pl_module):
```
Comment on lines +297 to +299:

- @awaelchli
- if ...
- Is this relevant here? I thought this PR is just about displaying the num_sanity steps that the trainer returns.
- Yeah, it still has some issues with `limit_val_batches`. I think a better fix would be to set up `num_sanity_val_steps` as a list in the Trainer itself rather than doing it here; then we can simply do a sum to get the total sanity val steps.
- Does that mean ...
- I suggest, in case of ...
- @rohitgr7 I like your suggestions. It is true, the trainer should compute these properties and the progress bars should only read them (and maybe sum them).
- Should I open another PR or keep this PR going? Should we use the same ...
- I am already working on it :)
- Thank you.
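To make the new progress-bar total concrete, here is a small standalone sketch; the dataset sizes, batch size, and `num_sanity_val_steps` value are made up for illustration. It shows how summing the per-dataloader minimum behaves when one validation loader is shorter than the configured number of sanity steps:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.data import has_len

num_sanity_val_steps = 5  # example value, as if passed to the Trainer

# two finite validation dataloaders: 3 batches and 10 batches respectively
short_loader = DataLoader(TensorDataset(torch.randn(12, 2)), batch_size=4)  # len == 3
long_loader = DataLoader(TensorDataset(torch.randn(40, 2)), batch_size=4)   # len == 10

total = sum(
    min(num_sanity_val_steps, len(d) if has_len(d) else float('inf'))
    for d in [short_loader, long_loader]
)
print(total)  # 3 + 5 = 8, rather than 5 * 2 = 10 from the old computation
```

A dataloader without a length (for example one wrapping an `IterableDataset`) contributes `num_sanity_val_steps` itself, since `min(steps, float('inf'))` is just `steps`.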
@@ -0,0 +1,41 @@ (new file, the `pytorch_lightning.utilities.data` module imported above)

```python
from distutils.version import LooseVersion

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import rank_zero_warn

try:
    from torch.utils.data import IterableDataset
    ITERABLE_DATASET_EXISTS = True
except ImportError:
    ITERABLE_DATASET_EXISTS = False


def has_iterable_dataset(dataloader: DataLoader):
    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
        and isinstance(dataloader.dataset, IterableDataset)


def has_len(dataloader: DataLoader) -> bool:
    """ Checks if a given Dataloader has __len__ method implemented i.e. if
    it is a finite dataloader or infinite dataloader. """

    try:
        # try getting the length
        if len(dataloader) == 0:
            raise ValueError('`Dataloader` returned 0 length.'
                             ' Please make sure that your Dataloader at least returns 1 batch')
        has_len = True
    except TypeError:
        has_len = False
    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
        has_len = False

    if has_len and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        rank_zero_warn(
            'Your `IterableDataset` has `__len__` defined.'
            ' In combination with multi-processing data loading (e.g. batch size > 1),'
            ' this can lead to unintended side effects since the samples will be duplicated.'
        )
    return has_len
```
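For orientation, here is a small example of how the two helpers classify dataloaders; the toy `InfiniteNoise` stream is invented for the demonstration:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

from pytorch_lightning.utilities.data import has_iterable_dataset, has_len


class InfiniteNoise(IterableDataset):
    """Toy stream without __len__, so the dataloader length is unknown."""

    def __iter__(self):
        while True:
            yield torch.randn(2)


finite_loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
stream_loader = DataLoader(InfiniteNoise(), batch_size=4)

assert has_len(finite_loader) and not has_iterable_dataset(finite_loader)  # map-style, 2 batches
assert not has_len(stream_loader) and has_iterable_dataset(stream_loader)  # length-less stream
```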
@@ -0,0 +1,39 @@ (new test file)

```python
import pytest
import torch
from packaging.version import parse
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from tests.base import EvalModelTemplate


@pytest.mark.xfail(
    parse(torch.__version__) < parse("1.4.0"),
    reason="IterableDataset with __len__ before 1.4 raises",
)
def test_warning_with_iterable_dataset_and_len(tmpdir):
    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
    model = EvalModelTemplate()
    original_dataset = model.train_dataloader().dataset

    class IterableWithLen(IterableDataset):

        def __iter__(self):
            return iter(original_dataset)

        def __len__(self):
            return len(original_dataset)

    dataloader = DataLoader(IterableWithLen(), batch_size=16)
    assert has_len(dataloader)
    assert has_iterable_dataset(dataloader)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=3,
    )
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
        trainer.test(model, test_dataloaders=[dataloader])
```
This is quite a common case; can't we add a function for this, like ... It may be an overhead now, but we really need similar things quite often.
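The concrete helper suggested above is cut off in the thread, but one possible shape for it, sketched purely as an illustration (the name `batches_to_run` and its placement are assumptions, not part of this PR), would be:

```python
from typing import Union

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import has_len


def batches_to_run(dataloader: DataLoader, limit: Union[int, float]) -> Union[int, float]:
    """Hypothetical helper: how many batches a loop will actually run given a
    configured limit, capped at the dataloader length when it defines one."""
    return min(limit, len(dataloader) if has_len(dataloader) else float('inf'))


# the progress-bar total from the diff above would then collapse to:
# total = sum(batches_to_run(d, trainer.num_sanity_val_steps) for d in trainer.val_dataloaders)
```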
I believe this is repeated code; it is already done in `reset_val_dataloader`. All we need is to sum `num_sanity_val_steps` here once #2917 is fixed.
Agree with both of you. Should we block this PR on #2917, or the other way around? Does it matter which one goes first?
I suggest blocking this one. Once I get answers to the questions I asked there, I'll fix that one tonight, and then we can complete this one :)
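To illustrate the direction discussed above: this is a sketch of the proposed follow-up, not code from this PR, and the attribute name `num_sanity_val_batches` plus its exact contents are assumptions about how #2917 might expose per-dataloader counts. The sanity-check hook in the progress bar would then shrink to a plain sum:

```python
from tqdm import tqdm


class SanityCheckProgressSketch:
    """Hypothetical slimmed-down sanity-check hook, assuming the Trainer
    pre-computes one capped batch count per validation dataloader."""

    def on_sanity_check_start(self, trainer, pl_module):
        # e.g. trainer.num_sanity_val_batches == [3, 5] for two val dataloaders,
        # after the Trainer has applied num_sanity_val_steps and limit_val_batches
        self.val_progress_bar = tqdm(desc='Validation sanity check', leave=False)
        self.val_progress_bar.total = sum(trainer.num_sanity_val_batches)
        self.main_progress_bar = tqdm(disable=True)  # dummy progress bar
```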