Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix/data fetcher (#260)
Browse files Browse the repository at this point in the history
* Add files via upload

* exposes ImageClassificationPreprocess to public api

* add crashing test

* fix fetcher issue

* fix linter
  • Loading branch information
edgarriba authored May 4, 2021
1 parent 8bf5d76 commit 54e569b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
4 changes: 2 additions & 2 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,15 @@ def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], re

iter_dataloader = getattr(self, iter_name)
with self.data_fetcher.enable():
if reset:
self.data_fetcher.batches[stage] = {}
try:
_ = next(iter_dataloader)
except StopIteration:
iter_dataloader = self._reset_iterator(stage)
_ = next(iter_dataloader)
data_fetcher: BaseVisualization = self.data_fetcher
data_fetcher._show(stage, func_names)
if reset:
self.viz.batches[stage] = {}

def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the train dataloader."""
Expand Down
18 changes: 12 additions & 6 deletions tests/data/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ class CustomImageClassificationData(ImageClassificationData):
def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
return CustomBaseVisualization(*args, **kwargs)

B: int = 2 # batch_size

dm = CustomImageClassificationData.from_filepaths(
train_filepaths=train_images,
train_labels=[0, 1],
Expand All @@ -139,16 +141,18 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
test_filepaths=train_images,
test_labels=[4, 5],
predict_filepaths=train_images,
batch_size=2,
batch_size=B,
num_workers=0,
)

num_tests = 10

for stage in _STAGES_PREFIX.values():

for _ in range(10):
for _ in range(num_tests):
for fcn_name in _PREPROCESS_FUNCS:
fcn = getattr(dm, f"show_{stage}_batch")
fcn(fcn_name, reset=False)
fcn(fcn_name, reset=True)

is_predict = stage == "predict"

Expand All @@ -161,32 +165,34 @@ def _get_result(function_name: str):
return dm.data_fetcher.batches[stage][function_name]

res = _get_result("load_sample")
assert len(res) == B
assert isinstance(_extract_data(res), Image.Image)

if not is_predict:
res = _get_result("load_sample")
assert isinstance(res[0][1], torch.Tensor)

res = _get_result("to_tensor_transform")
assert len(res) == B
assert isinstance(_extract_data(res), torch.Tensor)

if not is_predict:
res = _get_result("to_tensor_transform")
assert isinstance(res[0][1], torch.Tensor)

res = _get_result("collate")
assert _extract_data(res).shape == (2, 3, 196, 196)
assert _extract_data(res).shape == (B, 3, 196, 196)

if not is_predict:
res = _get_result("collate")
assert res[0][1].shape == torch.Size([2])

res = _get_result("per_batch_transform")
assert _extract_data(res).shape == (2, 3, 196, 196)
assert _extract_data(res).shape == (B, 3, 196, 196)

if not is_predict:
res = _get_result("per_batch_transform")
assert res[0][1].shape == (2, )
assert res[0][1].shape == (B, )

assert dm.data_fetcher.show_load_sample_called
assert dm.data_fetcher.show_pre_tensor_transform_called
Expand Down

0 comments on commit 54e569b

Please sign in to comment.