From 54e569b9b95711cf01e00432927103465eeefc4c Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Tue, 4 May 2021 09:36:10 +0200 Subject: [PATCH] Fix/data fetcher (#260) * Add files via upload * exposes ImageClassificationPreprocess to public api * add crashing test * fix fetcher issue * fix linter --- flash/data/data_module.py | 4 ++-- tests/data/test_callbacks.py | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index bcb3787268..874bcd8132 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -170,6 +170,8 @@ def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], re iter_dataloader = getattr(self, iter_name) with self.data_fetcher.enable(): + if reset: + self.data_fetcher.batches[stage] = {} try: _ = next(iter_dataloader) except StopIteration: @@ -177,8 +179,6 @@ def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], re _ = next(iter_dataloader) data_fetcher: BaseVisualization = self.data_fetcher data_fetcher._show(stage, func_names) - if reset: - self.viz.batches[stage] = {} def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: """This function is used to visualize a batch from the train dataloader.""" diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 468730b099..46a9347cfa 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -131,6 +131,8 @@ class CustomImageClassificationData(ImageClassificationData): def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: return CustomBaseVisualization(*args, **kwargs) + B: int = 2 # batch_size + dm = CustomImageClassificationData.from_filepaths( train_filepaths=train_images, train_labels=[0, 1], @@ -139,16 +141,18 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: test_filepaths=train_images, test_labels=[4, 5], predict_filepaths=train_images, - batch_size=2, + batch_size=B, num_workers=0, ) + num_tests = 10 + for stage in _STAGES_PREFIX.values(): - for _ in range(10): + for _ in range(num_tests): for fcn_name in _PREPROCESS_FUNCS: fcn = getattr(dm, f"show_{stage}_batch") - fcn(fcn_name, reset=False) + fcn(fcn_name, reset=True) is_predict = stage == "predict" @@ -161,6 +165,7 @@ def _get_result(function_name: str): return dm.data_fetcher.batches[stage][function_name] res = _get_result("load_sample") + assert len(res) == B assert isinstance(_extract_data(res), Image.Image) if not is_predict: @@ -168,6 +173,7 @@ def _get_result(function_name: str): assert isinstance(res[0][1], torch.Tensor) res = _get_result("to_tensor_transform") + assert len(res) == B assert isinstance(_extract_data(res), torch.Tensor) if not is_predict: @@ -175,18 +181,18 @@ def _get_result(function_name: str): assert isinstance(res[0][1], torch.Tensor) res = _get_result("collate") - assert _extract_data(res).shape == (2, 3, 196, 196) + assert _extract_data(res).shape == (B, 3, 196, 196) if not is_predict: res = _get_result("collate") assert res[0][1].shape == torch.Size([2]) res = _get_result("per_batch_transform") - assert _extract_data(res).shape == (2, 3, 196, 196) + assert _extract_data(res).shape == (B, 3, 196, 196) if not is_predict: res = _get_result("per_batch_transform") - assert res[0][1].shape == (2, ) + assert res[0][1].shape == (B, ) assert dm.data_fetcher.show_load_sample_called assert dm.data_fetcher.show_pre_tensor_transform_called