Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fixes predict batches should now be shown for image class… #438

Merged
merged 5 commits into from
Jun 21, 2021
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
@@ -136,9 +136,9 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str)
for i, ax in enumerate(axs.ravel()):
# unpack images and labels
if isinstance(data, list):
_img, _label = data[i][DefaultDataKeys.INPUT], data[i][DefaultDataKeys.TARGET]
_img, _label = data[i][DefaultDataKeys.INPUT], data[i].get(DefaultDataKeys.TARGET, "")
elif isinstance(data, dict):
_img, _label = data[DefaultDataKeys.INPUT][i], data[DefaultDataKeys.TARGET][i]
_img, _label = data[DefaultDataKeys.INPUT][i], data.get(DefaultDataKeys.TARGET, [""] * (i + 1))[i]
else:
raise TypeError(f"Unknown data type. Got: {type(data)}.")
# convert images to numpy