From cd3faf8ec4a4c0e7bc54654a92432f441e103cb6 Mon Sep 17 00:00:00 2001 From: frederik Date: Mon, 21 Jun 2021 15:25:18 +0200 Subject: [PATCH 1/3] (FIX) Fixes #430, predict batches should now be shown for image classification. --- flash/image/classification/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 412ee17c03..f8463cbaab 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -136,9 +136,9 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) for i, ax in enumerate(axs.ravel()): # unpack images and labels if isinstance(data, list): - _img, _label = data[i][DefaultDataKeys.INPUT], data[i][DefaultDataKeys.TARGET] + _img, _label = data[i][DefaultDataKeys.INPUT], data[i].get(DefaultDataKeys.TARGET, "") elif isinstance(data, dict): - _img, _label = data[DefaultDataKeys.INPUT][i], data[DefaultDataKeys.TARGET][i] + _img, _label = data[DefaultDataKeys.INPUT][i], data.get(DefaultDataKeys.TARGET, [""]*(i+1))[i] else: raise TypeError(f"Unknown data type. Got: {type(data)}.") # convert images to numpy From 543ae92e42ec3edecaa9fa9d5a378799587bc654 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Jun 2021 13:47:24 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/image/classification/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index f8463cbaab..08095c9893 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -138,7 +138,7 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) if isinstance(data, list): _img, _label = data[i][DefaultDataKeys.INPUT], data[i].get(DefaultDataKeys.TARGET, "") elif isinstance(data, dict): - _img, _label = data[DefaultDataKeys.INPUT][i], data.get(DefaultDataKeys.TARGET, [""]*(i+1))[i] + _img, _label = data[DefaultDataKeys.INPUT][i], data.get(DefaultDataKeys.TARGET, [""] * (i + 1))[i] else: raise TypeError(f"Unknown data type. Got: {type(data)}.") # convert images to numpy From d26ef674d6c6f11f22dd3e0281d7335891acb61f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 21 Jun 2021 20:30:01 +0100 Subject: [PATCH 3/3] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 22dedc2b06..28add0c6a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where the `DefaultDataKeys.METADATA` couldn't be a dict ([#393](https://github.com/PyTorchLightning/lightning-flash/pull/393)) - Fixed a bug where the `SemanticSegmentation` task would not work as expected with finetuning callbacks ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412)) +- Fixed a bug where predict batches could not be visualized with `ImageClassificationData` ([#438](https://github.com/PyTorchLightning/lightning-flash/pull/438)) ## [0.3.2] - 2021-06-08