From 7111e82105d3df385db75b197dfc2783431b1fe9 Mon Sep 17 00:00:00 2001 From: fstroth Date: Mon, 21 Jun 2021 21:39:51 +0200 Subject: [PATCH] =?UTF-8?q?Fixes=20predict=20batches=20should=20now=20be?= =?UTF-8?q?=20shown=20for=20image=20class=E2=80=A6=20(#438)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * (FIX) Fixes #430, predict batches should now be shown for image classification. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update CHANGELOG.md Co-authored-by: frederik Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ethan Harris Co-authored-by: Ethan Harris --- CHANGELOG.md | 1 + flash/image/classification/data.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 22dedc2b06..28add0c6a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where the `DefaultDataKeys.METADATA` couldn't be a dict ([#393](https://github.com/PyTorchLightning/lightning-flash/pull/393)) - Fixed a bug where the `SemanticSegmentation` task would not work as expected with finetuning callbacks ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412)) +- Fixed a bug where predict batches could not be visualized with `ImageClassificationData` ([#438](https://github.com/PyTorchLightning/lightning-flash/pull/438)) ## [0.3.2] - 2021-06-08 diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index ea029d3c0b..deb84f82a4 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -134,9 +134,9 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) for i, ax in enumerate(axs.ravel()): # unpack images and labels if isinstance(data, list): - _img, _label = data[i][DefaultDataKeys.INPUT], data[i][DefaultDataKeys.TARGET] + _img, _label = data[i][DefaultDataKeys.INPUT], data[i].get(DefaultDataKeys.TARGET, "") elif isinstance(data, dict): - _img, _label = data[DefaultDataKeys.INPUT][i], data[DefaultDataKeys.TARGET][i] + _img, _label = data[DefaultDataKeys.INPUT][i], data.get(DefaultDataKeys.TARGET, [""] * (i + 1))[i] else: raise TypeError(f"Unknown data type. Got: {type(data)}.") # convert images to numpy