From fcf716eb49166b38ab71ab10e13b23cc4454ee98 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Nov 2021 03:14:09 +0100 Subject: [PATCH 1/6] Improve error message on `TypeError` during `DataLoader` reconstruction --- pytorch_lightning/utilities/data.py | 21 +++++++++++++++++- tests/utilities/test_data.py | 33 +++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 5b56940460ca4..65ee49a6cff22 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -177,7 +177,26 @@ def get_len(dataloader: DataLoader) -> Union[int, float]: def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader: dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode) dl_cls = type(dataloader) - dataloader = dl_cls(**dl_kwargs) + try: + dataloader = dl_cls(**dl_kwargs) + except TypeError as e: + # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass + # `__init__` arguments map to one `DataLoader.__init__` argument + import re + + pattern = re.compile(r".*__init__\(\) got multiple values .* '(\w+)'") + match = re.match(pattern, str(e)) + if not match: + # an unexpected `TypeError`, continue failure + raise + argument = match.groups()[0] + message = ( + f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument" + f" can be passed to to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" + f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." + f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." + ) + raise MisconfigurationException(message) from e return dataloader diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index ae1f8c6505efc..9b9d3543d1bfb 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -4,6 +4,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.utilities.data import ( + _update_dataloader, extract_batch_size, get_len, has_iterable_dataset, @@ -112,3 +113,35 @@ def test_has_len_all_rank(): assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model) assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model) + + +def test_update_dataloader_typerror_custom_exception(): + class BadImpl(DataLoader): + def __init__(self, foo, *args, **kwargs): + # positional conflict with `dataset` + self.foo = foo + super().__init__(foo, *args, **kwargs) + + dataloader = BadImpl([1, 2, 3]) + with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): + _update_dataloader(dataloader, dataloader.sampler) + + class BadImpl2(DataLoader): + def __init__(self, randomize, *args, **kwargs): + # keyword conflict with `shuffle` + self.randomize = randomize + super().__init__(*args, shuffle=randomize, **kwargs) + + dataloader = BadImpl2(False, []) + with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"): + _update_dataloader(dataloader, dataloader.sampler) + + class GoodImpl(DataLoader): + def __init__(self, randomize, *args, **kwargs): + # fixed implementation + self.randomize = randomize or kwargs.pop("shuffle", False) + super().__init__(*args, shuffle=randomize, **kwargs) + + dataloader = GoodImpl(False, []) + new_dataloader = _update_dataloader(dataloader, dataloader.sampler) + assert isinstance(new_dataloader, GoodImpl) From 716823c0656abcd9ac89ddfcf0f45573c5dd7e3b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Nov 2021 03:19:27 +0100 Subject: [PATCH 2/6] Update CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f7014a71d9a3..d66f06c92b4b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - -- +- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719)) - From c9064adf4861b60a3d9f29c0451fbaa27c8cd716 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Nov 2021 03:21:47 +0100 Subject: [PATCH 3/6] Remove pattern saving --- pytorch_lightning/utilities/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 65ee49a6cff22..c273cf41fcc7f 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -184,8 +184,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[ # `__init__` arguments map to one `DataLoader.__init__` argument import re - pattern = re.compile(r".*__init__\(\) got multiple values .* '(\w+)'") - match = re.match(pattern, str(e)) + match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) if not match: # an unexpected `TypeError`, continue failure raise From 67131254827f6228cb997e56e1406adfd03a9082 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Nov 2021 03:22:16 +0100 Subject: [PATCH 4/6] Improve comments --- tests/utilities/test_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 9b9d3543d1bfb..967c5a176cef1 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -118,8 +118,8 @@ def test_has_len_all_rank(): def test_update_dataloader_typerror_custom_exception(): class BadImpl(DataLoader): def __init__(self, foo, *args, **kwargs): - # positional conflict with `dataset` self.foo = foo + # positional conflict with `dataset` super().__init__(foo, *args, **kwargs) dataloader = BadImpl([1, 2, 3]) @@ -128,8 +128,8 @@ def __init__(self, foo, *args, **kwargs): class BadImpl2(DataLoader): def __init__(self, randomize, *args, **kwargs): - # keyword conflict with `shuffle` self.randomize = randomize + # keyword conflict with `shuffle` super().__init__(*args, shuffle=randomize, **kwargs) dataloader = BadImpl2(False, []) @@ -138,7 +138,7 @@ def __init__(self, randomize, *args, **kwargs): class GoodImpl(DataLoader): def __init__(self, randomize, *args, **kwargs): - # fixed implementation + # fixed implementation, kwargs are filtered self.randomize = randomize or kwargs.pop("shuffle", False) super().__init__(*args, shuffle=randomize, **kwargs) From 870e1fd7bfd953920382af4fa825430d3af0cd5b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Nov 2021 15:14:28 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 20ce30874917e..e202941cf0fbb 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -4,8 +4,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.utilities.data import ( - _update_dataloader, _replace_dataloader_init_method, + _update_dataloader, extract_batch_size, get_len, has_iterable_dataset, From 4b9699ad3f00aee89635e5f24646bdadd2f8e97f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 24 Nov 2021 16:58:11 +0100 Subject: [PATCH 6/6] Update pytorch_lightning/utilities/data.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index cf58e81e8fecc..9963bf6c85ffd 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -194,7 +194,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[ argument = match.groups()[0] message = ( f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument" - f" can be passed to to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" + f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." )