27 changes: 15 additions & 12 deletions examples/visualization/plot_image_gallery.py
@@ -1,8 +1,9 @@
"""
Title: Plot an image gallery
Author: [lukewood](https://lukewood.xyz)
Author: [lukewood](https://lukewood.xyz), updated by
[Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
Date created: 2022/10/16
Last modified: 2022/10/16
Last modified: 2022/06/24
Description: Visualize a gallery of images from a given dataset.
"""
@@ -11,7 +12,7 @@
Plotting images from a TensorFlow dataset is easy with KerasCV. Behold:
"""

import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds

import keras_cv
@@ -23,18 +24,20 @@
shuffle_files=True,
)

keras_cv.visualization.plot_image_gallery(
train_ds,
value_range=(0, 255),
scale=3,
)

def unpackage_tfds_inputs(inputs):
return inputs["image"]
"""
If you want to use plain NumPy arrays, you can do that too:
"""

# Prepare a batch of sample images as a NumPy array of random noise

train_ds = train_ds.map(unpackage_tfds_inputs)
train_ds = train_ds.apply(tf.data.experimental.dense_to_ragged_batch(16))
samples = np.random.randint(0, 255, (20, 224, 224, 3))

keras_cv.visualization.plot_image_gallery(
next(iter(train_ds.take(1))),
value_range=(0, 255),
scale=3,
rows=2,
cols=2,
samples, value_range=(0, 255), scale=3, rows=4, cols=5
)
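Taken together, the updated example reduces to two call patterns, sketched below. The dataset name passed to `tfds.load()` is an assumption, since the load call itself is truncated in the hunk above.

```python
import numpy as np
import tensorflow_datasets as tfds

import keras_cv

# Pattern 1: pass a `tf.data.Dataset` directly; images are read from the
# "image" key, and the grid shape is inferred when `rows`/`cols` are omitted.
# ("cats_vs_dogs" is a stand-in for the truncated `tfds.load()` call above.)
train_ds = tfds.load("cats_vs_dogs", split="train[:1%]", shuffle_files=True)
keras_cv.visualization.plot_image_gallery(
    train_ds, value_range=(0, 255), scale=3
)

# Pattern 2: pass a rank-4 NumPy array (batch, height, width, channels) and
# choose the grid shape explicitly.
samples = np.random.randint(0, 255, (20, 224, 224, 3))
keras_cv.visualization.plot_image_gallery(
    samples, value_range=(0, 255), scale=3, rows=4, cols=5
)
```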
114 changes: 88 additions & 26 deletions keras_cv/visualization/plot_image_gallery.py
@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import tensorflow as tf

import keras_cv
from keras_cv import utils
from keras_cv.utils import assert_matplotlib_installed
@@ -22,12 +27,43 @@
plt = None


def _extract_image_batch(images, num_images, batch_size):
def unpack_images(inputs):
return inputs["image"]

num_batches_required = math.ceil(num_images / batch_size)

if isinstance(images, tf.data.Dataset):
images = images.map(unpack_images)

if batch_size == 1:
images = images.ragged_batch(num_batches_required)
sample = next(iter(images.take(1)))
else:
sample = next(iter(images.take(num_batches_required)))

return sample

else:
if len(images.shape) != 4:
raise ValueError(
"`plot_images_gallery()` requires you to "
"batch your `np.array` samples together."
)
else:
num_samples = (
num_images if num_images <= batch_size else num_batches_required
)
sample = images[:num_samples, ...]
return sample
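For the `np.array` branch above, a self-contained sketch with hypothetical shapes shows which slice ends up in the gallery; it mirrors the helper's logic rather than importing the private function.

```python
import math

import numpy as np

# Mirrors the `np.array` branch of `_extract_image_batch`: for a batched input
# with num_images <= batch_size, the result is simply the first `num_images`
# images.
images = np.random.randint(0, 255, (16, 64, 64, 3))
num_images, batch_size = 9, images.shape[0]
num_batches_required = math.ceil(num_images / batch_size)
num_samples = num_images if num_images <= batch_size else num_batches_required
sample = images[:num_samples, ...]
print(sample.shape)  # (9, 64, 64, 3)
```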


def plot_image_gallery(
images,
value_range,
rows=3,
cols=3,
scale=2,
rows=None,
cols=None,
path=None,
show=None,
transparent=True,
@@ -45,32 +81,27 @@ def plot_image_gallery(
shuffle_files=True,
)


def unpackage_tfds_inputs(inputs):
return inputs["image"]

train_ds = train_ds.map(unpackage_tfds_inputs)
train_ds = train_ds.apply(tf.data.experimental.dense_to_ragged_batch(16))

keras_cv.visualization.plot_image_gallery(
next(iter(train_ds.take(1))),
train_ds,
value_range=(0, 255),
scale=3,
rows=2,
cols=2,
)
```

![example gallery](https://i.imgur.com/r0ndse0.png)

Args:
images: a Tensor or NumPy array containing images to show in the
gallery.
images: a Tensor, `tf.data.Dataset` or NumPy array containing images
to show in the gallery. Note: If using a `tf.data.Dataset`,
images should be present in the `FeaturesDict` under
the key `image`.
value_range: value range of the images. Common examples include
`(0, 255)` and `(0, 1)`.
rows: number of rows in the gallery to show.
cols: number of columns in the gallery to show.
scale: how much to scale the images in the gallery.
rows: (Optional) number of rows in the gallery to show.
Required if inputs are unbatched.
cols: (Optional) number of columns in the gallery to show.
Required if inputs are unbatched.
path: (Optional) path to save the resulting gallery to.
show: (Optional) whether to show the gallery of images.
transparent: (Optional) whether to give the image a transparent
@@ -80,6 +111,7 @@ def unpackage_tfds_inputs(inputs):
legend_handles: (Optional) matplotlib.patches List of legend handles.
I.e. passing: `[patches.Patch(color='red', label='mylabel')]` will
produce a legend with a single red patch and the label 'mylabel'.

"""
assert_matplotlib_installed("plot_image_gallery")

@@ -92,27 +124,57 @@ def unpackage_tfds_inputs(inputs):
"to be true."
)

fig = plt.figure(figsize=(cols * scale, rows * scale))
fig.tight_layout() # Or equivalently, "plt.tight_layout()"
plt.subplots_adjust(wspace=0, hspace=0)
plt.margins(x=0, y=0)
plt.axis("off")
if isinstance(images, tf.data.Dataset):
sample = next(iter(images.take(1)))
batch_size = (
sample["image"].shape[0] if len(sample["image"].shape) == 4 else 1
) # batch_size from within passed `tf.data.Dataset`
else:
batch_size = (
images.shape[0] if len(images.shape) == 4 else 1
) # batch_size from np.array or single image

rows = rows or int(math.ceil(math.sqrt(batch_size)))
cols = cols or int(math.ceil(batch_size // rows))
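# Worked example: a batch of 16 images with `rows`/`cols` left as None gives
# rows = ceil(sqrt(16)) = 4 and cols = 16 // 4 = 4, i.e. a 4x4 grid showing
# all 16 images.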

num_images = rows * cols
images = _extract_image_batch(images, num_images, batch_size)

# Generate subplots
fig, axes = plt.subplots(
nrows=rows,
ncols=cols,
figsize=(cols * scale, rows * scale),
frameon=False,
layout="tight",
squeeze=True,
sharex="row",
sharey="col",
)
fig.subplots_adjust(wspace=0, hspace=0)
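# With `squeeze=True`, `plt.subplots` returns a scalar Axes for a 1x1 grid and
# a 1-D array for a single row or column. The check below expands a 1-D array
# back to 2-D so that `axes[row, col]` works; the scalar case is handled in
# the drawing loop.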

if isinstance(axes, np.ndarray) and len(axes.shape) == 1:
expand_axis = 0 if rows == 1 else -1
axes = np.expand_dims(axes, expand_axis)

if legend_handles is not None:
fig.legend(handles=legend_handles, loc="lower center")

# Perform image range transform
images = keras_cv.utils.transform_value_range(
images, original_range=value_range, target_range=(0, 255)
)
images = utils.to_numpy(images)
images = images.astype(int)

for row in range(rows):
for col in range(cols):
index = row * cols + col
plt.subplot(rows, cols, index + 1)
plt.imshow(images[index].astype("uint8"))
plt.axis("off")
plt.margins(x=0, y=0)
current_axis = (
axes[row, col] if isinstance(axes, np.ndarray) else axes
)
current_axis.imshow(images[index].astype("uint8"))
current_axis.margins(x=0, y=0)
current_axis.axis("off")

if path is None and not show:
return
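The remainder of the function (truncated here) handles saving and display. A minimal usage sketch of the documented `path` argument follows, with a hypothetical output filename:

```python
import numpy as np

import keras_cv

# Write the gallery to disk instead of displaying it; "gallery.png" is a
# hypothetical output path.
images = np.random.randint(0, 255, (9, 64, 64, 3))
keras_cv.visualization.plot_image_gallery(
    images, value_range=(0, 255), rows=3, cols=3, path="gallery.png"
)
```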