diff --git a/CHANGELOG.md b/CHANGELOG.md index 22dedc2b06..28add0c6a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where the `DefaultDataKeys.METADATA` couldn't be a dict ([#393](https://github.com/PyTorchLightning/lightning-flash/pull/393)) - Fixed a bug where the `SemanticSegmentation` task would not work as expected with finetuning callbacks ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412)) +- Fixed a bug where predict batches could not be visualized with `ImageClassificationData` ([#438](https://github.com/PyTorchLightning/lightning-flash/pull/438)) ## [0.3.2] - 2021-06-08 diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index ea029d3c0b..deb84f82a4 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -134,9 +134,9 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) for i, ax in enumerate(axs.ravel()): # unpack images and labels if isinstance(data, list): - _img, _label = data[i][DefaultDataKeys.INPUT], data[i][DefaultDataKeys.TARGET] + _img, _label = data[i][DefaultDataKeys.INPUT], data[i].get(DefaultDataKeys.TARGET, "") elif isinstance(data, dict): - _img, _label = data[DefaultDataKeys.INPUT][i], data[DefaultDataKeys.TARGET][i] + _img, _label = data[DefaultDataKeys.INPUT][i], data.get(DefaultDataKeys.TARGET, [""] * (i + 1))[i] else: raise TypeError(f"Unknown data type. Got: {type(data)}.") # convert images to numpy