This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

show_predict_batch fails due to no targets being available #430

Closed
fstroth opened this issue Jun 19, 2021 · 1 comment · Fixed by #438
Labels
bug / fix (Something isn't working), help wanted (Extra attention is needed)

Comments

fstroth commented Jun 19, 2021

🐛 Bug

Calling the show_predict_batch method on image classification data raises a KeyError because predict samples have no targets. The failing code is:

for i, ax in enumerate(axs.ravel()):
    # unpack images and labels
    if isinstance(data, list):
        _img, _label = data[i][DefaultDataKeys.INPUT], data[i][DefaultDataKeys.TARGET]
    elif isinstance(data, dict):
        _img, _label = data[DefaultDataKeys.INPUT][i], data[DefaultDataKeys.TARGET][i]
    else:
        raise TypeError(f"Unknown data type. Got: {type(data)}.")
    # convert images to numpy
    _img: np.ndarray = self._to_numpy(_img)
    if isinstance(_label, torch.Tensor):
        _label = _label.squeeze().tolist()
    # show image and set label as subplot title
    ax.imshow(_img)
    ax.set_title(str(_label))
    ax.axis('off')

The fix should be simple:

for i, ax in enumerate(axs.ravel()):
    # unpack images and labels
    if isinstance(data, list):
        # use the get method to return an empty string if no targets are available
        _img, _label = data[i][DefaultDataKeys.INPUT], data[i].get(DefaultDataKeys.TARGET, "")
    elif isinstance(data, dict):
        # fall back to an empty string if the batch has no targets
        _img = data[DefaultDataKeys.INPUT][i]
        _label = data[DefaultDataKeys.TARGET][i] if DefaultDataKeys.TARGET in data else ""
    else:
        raise TypeError(f"Unknown data type. Got: {type(data)}.")
    # convert images to numpy
    _img: np.ndarray = self._to_numpy(_img)
    if isinstance(_label, torch.Tensor):
        _label = _label.squeeze().tolist()
    # show image and set label as subplot title
    ax.imshow(_img)
    ax.set_title(str(_label))
    ax.axis('off')
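
The lookup pattern can be tried without flash or matplotlib. The sketch below is only an illustration: `DataKeys` is a stand-in enum for `DefaultDataKeys`, and `unpack_sample` is a hypothetical helper, neither of which exists in flash. It assumes predict samples carry an input but no target.

```python
from enum import Enum
from typing import Any, Dict, List, Tuple, Union


class DataKeys(Enum):
    """Stand-in for flash's DefaultDataKeys (assumption, not the real class)."""
    INPUT = "input"
    TARGET = "target"


Batch = Union[List[Dict[DataKeys, Any]], Dict[DataKeys, List[Any]]]


def unpack_sample(data: Batch, i: int) -> Tuple[Any, Any]:
    """Hypothetical helper: return (image, label), with "" when no target exists."""
    if isinstance(data, list):
        sample = data[i]
        # Pass the enum member itself as the key. Wrapping it in a list
        # would raise TypeError, because lists are unhashable.
        return sample[DataKeys.INPUT], sample.get(DataKeys.TARGET, "")
    if isinstance(data, dict):
        # Check membership instead of indexing a default list: a
        # one-element default such as [""] would only cover i == 0.
        label = data[DataKeys.TARGET][i] if DataKeys.TARGET in data else ""
        return data[DataKeys.INPUT][i], label
    raise TypeError(f"Unknown data type. Got: {type(data)}.")


# Predict-style data has inputs only; train-style data also has targets.
predict_samples = [{DataKeys.INPUT: "img0"}, {DataKeys.INPUT: "img1"}]
predict_batch = {DataKeys.INPUT: ["img0", "img1"]}
train_batch = {DataKeys.INPUT: ["img0", "img1"], DataKeys.TARGET: [0, 1]}

print(unpack_sample(predict_samples, 1))  # ('img1', '')
print(unpack_sample(predict_batch, 1))    # ('img1', '')
print(unpack_sample(train_batch, 1))      # ('img1', 1)
```

With an empty string as the fallback label, `ax.set_title(str(_label))` sets an empty title, so the images are drawn without labels.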

I can create a PR later, when I have time.

To Reproduce

Install flash and run the code sample below.

Code sample

from flash.core.data.utils import download_data
from flash.image import ImageClassificationData

download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    predict_folder="data/hymenoptera_data/predict/"
)

datamodule.show_predict_batch()

This will give the following error message:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-4-ff8f29471c71> in <module>
----> 1 datamodule.show_predict_batch()

~/anaconda3/lib/python3.8/site-packages/flash/core/data/data_module.py in show_predict_batch(self, hooks_names, reset)
    225         """This function is used to visualize a batch from the predict dataloader."""
    226         stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING]
--> 227         self._show_batch(stage_name, hooks_names, reset=reset)
    228 
    229     @staticmethod

~/anaconda3/lib/python3.8/site-packages/flash/core/data/data_module.py in _show_batch(self, stage, func_names, reset)
    203                 _ = next(iter_dataloader)
    204             data_fetcher: BaseVisualization = self.data_fetcher
--> 205             data_fetcher._show(stage, func_names)
    206             if reset:
    207                 self.data_fetcher.batches[stage] = {}

~/anaconda3/lib/python3.8/site-packages/flash/core/data/base_viz.py in _show(self, running_stage, func_names_list)
    110 
    111     def _show(self, running_stage: RunningStage, func_names_list: List[str]) -> None:
--> 112         self.show(self.batches[running_stage], running_stage, func_names_list)
    113 
    114     def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_list: List[str]) -> None:

~/anaconda3/lib/python3.8/site-packages/flash/core/data/base_viz.py in show(self, batch, running_stage, func_names_list)
    124             hook_name = f"show_{func_name}"
    125             if _is_overriden(hook_name, self, BaseVisualization):
--> 126                 getattr(self, hook_name)(batch[func_name], running_stage)
    127 
    128     def show_load_sample(self, samples: List[Any], running_stage: RunningStage):

~/anaconda3/lib/python3.8/site-packages/flash/image/classification/data.py in show_load_sample(self, samples, running_stage)
    144     def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
    145         win_title: str = f"{running_stage} - show_load_sample"
--> 146         self._show_images_and_labels(samples, len(samples), win_title)
    147 
    148     def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):

~/anaconda3/lib/python3.8/site-packages/flash/image/classification/data.py in _show_images_and_labels(self, data, num_samples, title)
    127             # unpack images and labels
    128             if isinstance(data, list):
--> 129                 _img, _label = data[i][DefaultDataKeys.INPUT], data[i][DefaultDataKeys.TARGET]
    130             elif isinstance(data, dict):
    131                 _img, _label = data[DefaultDataKeys.INPUT][i], data[DefaultDataKeys.TARGET][i]

KeyError: <DefaultDataKeys.TARGET: 'target'>

Expected behavior

The batch should be shown without labels.

fstroth added the bug / fix and help wanted labels on Jun 19, 2021
@ethanwharris

@fstroth thanks for the bug report! Your fix looks good, happy to have a PR 😃

fstroth pushed a commit to fstroth/lightning-flash that referenced this issue Jun 21, 2021
ethanwharris added a commit that referenced this issue Jun 21, 2021
* (FIX) Fixes #430, predict batches should now be shown for image classification.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update CHANGELOG.md

Co-authored-by: frederik <fstroth>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>