From 1f6abec09b2d3ecf5ddddfc0f9d6188285c6d362 Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Tue, 30 May 2023 12:52:25 +0530 Subject: [PATCH 1/9] feat: add support for tf.data.dataset removed hard dependency on having rows and cols params, added support for processing tf.data.Dataset internally Signed-off-by: Suvaditya Mukherjee --- keras_cv/visualization/plot_image_gallery.py | 86 ++++++++++++++------ 1 file changed, 63 insertions(+), 23 deletions(-) diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index bb9bb062be..38d320c660 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import keras_cv +import tensorflow as tf from keras_cv import utils from keras_cv.utils import assert_matplotlib_installed @@ -25,14 +27,16 @@ def plot_image_gallery( images, value_range, - rows=3, - cols=3, + rows=None, + cols=None, + batch_size=8, scale=2, path=None, show=None, transparent=True, dpi=60, legend_handles=None, + image_key="image", ): """Displays a gallery of images. @@ -45,19 +49,10 @@ def plot_image_gallery( shuffle_files=True, ) - - def unpackage_tfds_inputs(inputs): - return inputs["image"] - - train_ds = train_ds.map(unpackage_tfds_inputs) - train_ds = train_ds.apply(tf.data.experimental.dense_to_ragged_batch(16)) - keras_cv.visualization.plot_image_gallery( - next(iter(train_ds.take(1))), + train_ds, value_range=(0, 255), scale=3, - rows=2, - cols=2, ) ``` @@ -68,9 +63,9 @@ def unpackage_tfds_inputs(inputs): gallery. value_range: value range of the images. Common examples include `(0, 255)` and `(0, 1)`. - rows: number of rows in the gallery to show. - cols: number of columns in the gallery to show. scale: how large to scale the images in the gallery + rows: (Optional) number of rows in the gallery to show. + cols: (Optional) number of columns in the gallery to show. path: (Optional) path to save the resulting gallery to. show: (Optional) whether to show the gallery of images. transparent: (Optional) whether to give the image a transparent @@ -80,6 +75,9 @@ def unpackage_tfds_inputs(inputs): legend_handles: (Optional) matplotlib.patches List of legend handles. I.e. passing: `[patches.Patch(color='red', label='mylabel')]` will produce a legend with a single red patch and the label 'mylabel'. + image_key: (Optional) Key of the argument holding the image. Only + required when using a `tf.data.Dataset` instance. Defaults to + "image". """ assert_matplotlib_installed("plot_bounding_box_gallery") @@ -92,27 +90,69 @@ def unpackage_tfds_inputs(inputs): "to be true." ) - fig = plt.figure(figsize=(cols * scale, rows * scale)) - fig.tight_layout() # Or equivalently, "plt.tight_layout()" - plt.subplots_adjust(wspace=0, hspace=0) - plt.margins(x=0, y=0) - plt.axis("off") + if isinstance(images, tf.data.Dataset): + + # Find final dataset batch size + sample = next(iter(images.take(1))) + if len(sample[image_key].shape) == 3: + default_dataset_batch_size = 8 + images = images.ragged_batch(batch_size=default_dataset_batch_size) + elif len(sample[image_key].shape) == 4: + default_dataset_batch_size = sample[image_key].shape[0] + else: + raise ValueError( + "plot_image_gallery() expects `tf.data.Dataset` to have TensorShape with length 3 or 4." + ) + + batches = default_dataset_batch_size + + def unpack_images(inputs): + return inputs[image_key] + + images = images.map(unpack_images) + images = images.take(batches) + images = next(iter(images)) + + # Calculate appropriate number of rows and columns + if rows is None and cols is None: + total_plots = batch_size + cols = batch_size // 2 + + rows = total_plots // cols + + if total_plots % cols != 0: + rows += 1 + + # Generate subplots + fig, axes = plt.subplots( + nrows=rows, + ncols=cols, + figsize=(cols * scale, rows * scale), + layout="tight", + squeeze=True, + sharex="row", + sharey="col", + ) + fig.subplots_adjust(wspace=0, hspace=0) if legend_handles is not None: fig.legend(handles=legend_handles, loc="lower center") + # Perform image range transform images = keras_cv.utils.transform_value_range( images, original_range=value_range, target_range=(0, 255) ) + images = utils.to_numpy(images) images = images.astype(int) + for row in range(rows): for col in range(cols): index = row * cols + col - plt.subplot(rows, cols, index + 1) - plt.imshow(images[index].astype("uint8")) - plt.axis("off") - plt.margins(x=0, y=0) + current_axis = axes[row, col] + current_axis.imshow(images[index].astype("uint8")) + current_axis.margins(x=0, y=0) + current_axis.axis("off") if path is None and not show: return From 92a51a4443ecf49dd35284de305d8e3b0e17cf90 Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Tue, 30 May 2023 13:02:19 +0530 Subject: [PATCH 2/9] chore: removed unused numpy import Signed-off-by: Suvaditya Mukherjee --- keras_cv/visualization/plot_image_gallery.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index 38d320c660..4ab0b6172d 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import keras_cv import tensorflow as tf from keras_cv import utils @@ -27,10 +26,10 @@ def plot_image_gallery( images, value_range, + scale=2, rows=None, cols=None, batch_size=8, - scale=2, path=None, show=None, transparent=True, @@ -66,6 +65,8 @@ def plot_image_gallery( scale: how large to scale the images in the gallery rows: (Optional) number of rows in the gallery to show. cols: (Optional) number of columns in the gallery to show. + batch_size: (Optional) batch size of a given `tf.data.Dataset` instance. + Defaults to 8. Only required when using a `tf.data.Dataset` instance. path: (Optional) path to save the resulting gallery to. show: (Optional) whether to show the gallery of images. transparent: (Optional) whether to give the image a transparent From 3ad5503f722a5ce53ba1fa55c2449709bb00d28e Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Tue, 30 May 2023 13:03:16 +0530 Subject: [PATCH 3/9] chore: added `tf.data.Dataset` to docstring Signed-off-by: Suvaditya Mukherjee --- keras_cv/visualization/plot_image_gallery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index 4ab0b6172d..8d285fdab3 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -58,7 +58,7 @@ def plot_image_gallery( ![example gallery](https://i.imgur.com/r0ndse0.png) Args: - images: a Tensor or NumPy array containing images to show in the + images: a Tensor, `tf.data.Dataset` or NumPy array containing images to show in the gallery. value_range: value range of the images. Common examples include `(0, 255)` and `(0, 1)`. From a7062dc253142819a00f66fa48d392fecf39a3b5 Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Wed, 31 May 2023 01:01:58 +0530 Subject: [PATCH 4/9] chore: addressed comments and modified example file Signed-off-by: Suvaditya Mukherjee --- examples/visualization/plot_image_gallery.py | 15 +--- keras_cv/visualization/plot_image_gallery.py | 82 ++++++++++++-------- 2 files changed, 52 insertions(+), 45 deletions(-) diff --git a/examples/visualization/plot_image_gallery.py b/examples/visualization/plot_image_gallery.py index 297cc6d0a8..b85bfe8a79 100644 --- a/examples/visualization/plot_image_gallery.py +++ b/examples/visualization/plot_image_gallery.py @@ -2,7 +2,7 @@ Title: Plot an image gallery Author: [lukewood](https://lukewood.xyz) Date created: 2022/10/16 -Last modified: 2022/10/16 +Last modified: 2022/05/31 Description: Visualize ground truth and predicted bounding boxes for a given dataset. """ @@ -11,7 +11,6 @@ Plotting images from a TensorFlow dataset is easy with KerasCV. Behold: """ -import tensorflow as tf import tensorflow_datasets as tfds import keras_cv @@ -23,18 +22,10 @@ shuffle_files=True, ) - -def unpackage_tfds_inputs(inputs): - return inputs["image"] - - -train_ds = train_ds.map(unpackage_tfds_inputs) -train_ds = train_ds.apply(tf.data.experimental.dense_to_ragged_batch(16)) +train_ds = train_ds.ragged_batch(16) keras_cv.visualization.plot_image_gallery( - next(iter(train_ds.take(1))), + train_ds, value_range=(0, 255), scale=3, - rows=2, - cols=2, ) diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index 8d285fdab3..4230ec898d 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import keras_cv import tensorflow as tf from keras_cv import utils @@ -29,13 +30,11 @@ def plot_image_gallery( scale=2, rows=None, cols=None, - batch_size=8, path=None, show=None, transparent=True, dpi=60, legend_handles=None, - image_key="image", ): """Displays a gallery of images. @@ -65,8 +64,6 @@ def plot_image_gallery( scale: how large to scale the images in the gallery rows: (Optional) number of rows in the gallery to show. cols: (Optional) number of columns in the gallery to show. - batch_size: (Optional) batch size of a given `tf.data.Dataset` instance. - Defaults to 8. Only required when using a `tf.data.Dataset` instance. path: (Optional) path to save the resulting gallery to. show: (Optional) whether to show the gallery of images. transparent: (Optional) whether to give the image a transparent @@ -76,9 +73,10 @@ def plot_image_gallery( legend_handles: (Optional) matplotlib.patches List of legend handles. I.e. passing: `[patches.Patch(color='red', label='mylabel')]` will produce a legend with a single red patch and the label 'mylabel'. - image_key: (Optional) Key of the argument holding the image. Only - required when using a `tf.data.Dataset` instance. Defaults to - "image". + + Note: + If using a `tf.data.Dataset`, it is important that the images present in + the `FeaturesDict` should have the key `image`. """ assert_matplotlib_installed("plot_bounding_box_gallery") @@ -91,38 +89,58 @@ def plot_image_gallery( "to be true." ) - if isinstance(images, tf.data.Dataset): + def unpack_images(inputs): + return inputs["image"] - # Find final dataset batch size - sample = next(iter(images.take(1))) - if len(sample[image_key].shape) == 3: - default_dataset_batch_size = 8 - images = images.ragged_batch(batch_size=default_dataset_batch_size) - elif len(sample[image_key].shape) == 4: - default_dataset_batch_size = sample[image_key].shape[0] + # Calculate appropriate number of rows and columns + if rows is None and cols is None: + if isinstance(images, tf.data.Dataset): + sample = next(iter(images.take(1))) + sample_shape = sample["image"].shape + if len(sample_shape) == 4: + batch_size = sample_shape[0] + else: + raise ValueError( + "Passed `tf.data.Dataset` does not appear to be batched. Please batch using the `.batch().`" + ) + + images = images.map(unpack_images) + images = images.take(batch_size) + images = next(iter(images)) else: - raise ValueError( - "plot_image_gallery() expects `tf.data.Dataset` to have TensorShape with length 3 or 4." - ) + sample_shape = images.shape + if len(sample_shape) == 4: + batch_size = sample_shape[0] + else: + raise ValueError( + f"Passed '`{type(images)}`' does not appear to be batched. Please batch using the `.batch()." + ) - batches = default_dataset_batch_size + elif rows is not None and cols is not None: + if isinstance(images, tf.data.Dataset): + batch_size = rows * cols - def unpack_images(inputs): - return inputs[image_key] + sample = next(iter(images.take(1))) + sample_shape = sample["image"].shape - images = images.map(unpack_images) - images = images.take(batches) - images = next(iter(images)) + if len(sample_shape) == 4: + images = images.unbatch() - # Calculate appropriate number of rows and columns - if rows is None and cols is None: - total_plots = batch_size - cols = batch_size // 2 + images = images.ragged_batch(batch_size=batch_size) - rows = total_plots // cols + images = images.map(unpack_images) + images = images.take(batch_size) + images = next(iter(images)) + else: + batch_size = rows * cols + images = images[:batch_size, ...] + else: + raise ValueError( + "plot_image_gallery() expects `tf.data.Dataset` to be batched if rows and cols are not specified." + ) - if total_plots % cols != 0: - rows += 1 + rows = int(math.ceil(batch_size**0.5)) + cols = int(math.ceil(batch_size // rows)) # Generate subplots fig, axes = plt.subplots( @@ -143,9 +161,7 @@ def unpack_images(inputs): images = keras_cv.utils.transform_value_range( images, original_range=value_range, target_range=(0, 255) ) - images = utils.to_numpy(images) - images = images.astype(int) for row in range(rows): for col in range(cols): From 3e069438813dbfaaac7fccde56d9a95324e0e921 Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Wed, 31 May 2023 13:17:11 +0530 Subject: [PATCH 5/9] chore: addressed newer comments updated error messages, added example with nparrays, corrected conditions Signed-off-by: Suvaditya Mukherjee --- examples/visualization/plot_image_gallery.py | 18 +++++++++++++++++- keras_cv/visualization/plot_image_gallery.py | 11 ++++++----- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/examples/visualization/plot_image_gallery.py b/examples/visualization/plot_image_gallery.py index b85bfe8a79..28ea935909 100644 --- a/examples/visualization/plot_image_gallery.py +++ b/examples/visualization/plot_image_gallery.py @@ -12,7 +12,7 @@ """ import tensorflow_datasets as tfds - +import numpy as np import keras_cv train_ds = tfds.load( @@ -29,3 +29,19 @@ value_range=(0, 255), scale=3, ) + +""" +If you want to use plain NumPy arrays, you can do that too: +""" + +# Prepare some NumPy arrays from random noise + +samples = [] +for sample in train_ds.take(20): + samples.append(sample["image"].numpy()) + +samples = np.array(samples, dtype="object") + +keras_cv.visualization.plot_image_gallery( + samples, value_range=(0, 255), scale=3, rows=4, cols=5 +) diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index 4230ec898d..b0057f88ce 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -62,8 +62,8 @@ def plot_image_gallery( value_range: value range of the images. Common examples include `(0, 255)` and `(0, 1)`. scale: how large to scale the images in the gallery - rows: (Optional) number of rows in the gallery to show. - cols: (Optional) number of columns in the gallery to show. + rows: (Optional) number of rows in the gallery to show. Required if inputs are unbatched. + cols: (Optional) number of columns in the gallery to show. Required if inputs are unbatched. path: (Optional) path to save the resulting gallery to. show: (Optional) whether to show the gallery of images. transparent: (Optional) whether to give the image a transparent @@ -93,7 +93,7 @@ def unpack_images(inputs): return inputs["image"] # Calculate appropriate number of rows and columns - if rows is None and cols is None: + if rows is None or cols is None: if isinstance(images, tf.data.Dataset): sample = next(iter(images.take(1))) sample_shape = sample["image"].shape @@ -113,7 +113,8 @@ def unpack_images(inputs): batch_size = sample_shape[0] else: raise ValueError( - f"Passed '`{type(images)}`' does not appear to be batched. Please batch using the `.batch()." + f"`plot_image_gallery` received unbatched images and `cols` and `rows` " + "were both `None`. Either images should be batched, or `cols` and `rows` should be specified." ) elif rows is not None and cols is not None: @@ -136,7 +137,7 @@ def unpack_images(inputs): images = images[:batch_size, ...] else: raise ValueError( - "plot_image_gallery() expects `tf.data.Dataset` to be batched if rows and cols are not specified." + "plot_image_gallery() expects `tf.data.Dataset` to be batched if rows or cols are not specified." ) rows = int(math.ceil(batch_size**0.5)) From 785aa1f8966b7142b566a78b191440061232d85c Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Sat, 24 Jun 2023 20:44:33 +0530 Subject: [PATCH 6/9] fix: refactored code addressed most comments, have rewritten the code from scratch to find a new and cleaner method to work with the images Signed-off-by: Suvaditya Mukherjee --- examples/visualization/plot_image_gallery.py | 14 +-- keras_cv/visualization/plot_image_gallery.py | 110 +++++++++---------- 2 files changed, 58 insertions(+), 66 deletions(-) diff --git a/examples/visualization/plot_image_gallery.py b/examples/visualization/plot_image_gallery.py index 28ea935909..b1cd824fa6 100644 --- a/examples/visualization/plot_image_gallery.py +++ b/examples/visualization/plot_image_gallery.py @@ -1,8 +1,8 @@ """ Title: Plot an image gallery -Author: [lukewood](https://lukewood.xyz) +Author: [lukewood](https://lukewood.xyz), updated by [Suvaditya Mukherjee](https://twitter.com/halcyonrayes) Date created: 2022/10/16 -Last modified: 2022/05/31 +Last modified: 2022/06/24 Description: Visualize ground truth and predicted bounding boxes for a given dataset. """ @@ -22,8 +22,6 @@ shuffle_files=True, ) -train_ds = train_ds.ragged_batch(16) - keras_cv.visualization.plot_image_gallery( train_ds, value_range=(0, 255), @@ -34,13 +32,9 @@ If you want to use plain NumPy arrays, you can do that too: """ -# Prepare some NumPy arrays from random noise - -samples = [] -for sample in train_ds.take(20): - samples.append(sample["image"].numpy()) +# Prepare some sample NumPy arrays from random noise -samples = np.array(samples, dtype="object") +samples = np.random.randint(0, 255, (20, 224, 224, 3)) keras_cv.visualization.plot_image_gallery( samples, value_range=(0, 255), scale=3, rows=4, cols=5 diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index b0057f88ce..a05043ca5a 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -14,6 +14,7 @@ import math import keras_cv +import numpy as np import tensorflow as tf from keras_cv import utils from keras_cv.utils import assert_matplotlib_installed @@ -24,6 +25,36 @@ plt = None +def _extract_image_batch(images, num_images, batch_size): + def unpack_images(inputs): + return inputs["image"] + + num_batches_required = math.ceil(num_images / batch_size) + + if isinstance(images, tf.data.Dataset): + images = images.map(unpack_images) + + if batch_size == 1: + images = images.ragged_batch(num_batches_required) + sample = next(iter(images.take(1))) + else: + sample = next(iter(images.take(num_batches_required))) + + return sample + + else: + if len(images.shape) != 4: + raise ValueError( + "`plot_images_gallery()` requires you to batch your `np.array` samples together." + ) + else: + num_samples = ( + num_images if num_images <= batch_size else num_batches_required + ) + sample = images[:num_samples, ...] + return sample + + def plot_image_gallery( images, value_range, @@ -58,7 +89,8 @@ def plot_image_gallery( Args: images: a Tensor, `tf.data.Dataset` or NumPy array containing images to show in the - gallery. + gallery. Note: If using a `tf.data.Dataset`, images should be + present in the `FeaturesDict` under the key `image`. value_range: value range of the images. Common examples include `(0, 255)` and `(0, 1)`. scale: how large to scale the images in the gallery @@ -74,9 +106,6 @@ def plot_image_gallery( I.e. passing: `[patches.Patch(color='red', label='mylabel')]` will produce a legend with a single red patch and the label 'mylabel'. - Note: - If using a `tf.data.Dataset`, it is important that the images present in - the `FeaturesDict` should have the key `image`. """ assert_matplotlib_installed("plot_bounding_box_gallery") @@ -89,65 +118,28 @@ def plot_image_gallery( "to be true." ) - def unpack_images(inputs): - return inputs["image"] - - # Calculate appropriate number of rows and columns - if rows is None or cols is None: - if isinstance(images, tf.data.Dataset): - sample = next(iter(images.take(1))) - sample_shape = sample["image"].shape - if len(sample_shape) == 4: - batch_size = sample_shape[0] - else: - raise ValueError( - "Passed `tf.data.Dataset` does not appear to be batched. Please batch using the `.batch().`" - ) - - images = images.map(unpack_images) - images = images.take(batch_size) - images = next(iter(images)) - else: - sample_shape = images.shape - if len(sample_shape) == 4: - batch_size = sample_shape[0] - else: - raise ValueError( - f"`plot_image_gallery` received unbatched images and `cols` and `rows` " - "were both `None`. Either images should be batched, or `cols` and `rows` should be specified." - ) - - elif rows is not None and cols is not None: - if isinstance(images, tf.data.Dataset): - batch_size = rows * cols - - sample = next(iter(images.take(1))) - sample_shape = sample["image"].shape - - if len(sample_shape) == 4: - images = images.unbatch() - - images = images.ragged_batch(batch_size=batch_size) - - images = images.map(unpack_images) - images = images.take(batch_size) - images = next(iter(images)) - else: - batch_size = rows * cols - images = images[:batch_size, ...] + if isinstance(images, tf.data.Dataset): + sample = next(iter(images.take(1))) + batch_size = ( + sample["image"].shape[0] if len(sample["image"].shape) == 4 else 1 + ) # batch_size from within passed `tf.data.Dataset` else: - raise ValueError( - "plot_image_gallery() expects `tf.data.Dataset` to be batched if rows or cols are not specified." - ) + batch_size = ( + images.shape[0] if len(images.shape) == 4 else 1 + ) # batch_size from np.array or single image + + rows = rows or int(math.ceil(math.sqrt(batch_size))) + cols = cols or int(math.ceil(batch_size // rows)) - rows = int(math.ceil(batch_size**0.5)) - cols = int(math.ceil(batch_size // rows)) + num_images = rows * cols + images = _extract_image_batch(images, num_images, batch_size) # Generate subplots fig, axes = plt.subplots( nrows=rows, ncols=cols, figsize=(cols * scale, rows * scale), + frameon=False, layout="tight", squeeze=True, sharex="row", @@ -155,6 +147,10 @@ def unpack_images(inputs): ) fig.subplots_adjust(wspace=0, hspace=0) + if isinstance(axes, np.ndarray) and len(axes.shape) == 1: + expand_axis = 0 if rows == 1 else -1 + axes = np.expand_dims(axes, expand_axis) + if legend_handles is not None: fig.legend(handles=legend_handles, loc="lower center") @@ -167,7 +163,9 @@ def unpack_images(inputs): for row in range(rows): for col in range(cols): index = row * cols + col - current_axis = axes[row, col] + current_axis = ( + axes[row, col] if isinstance(axes, np.ndarray) else axes + ) current_axis.imshow(images[index].astype("uint8")) current_axis.margins(x=0, y=0) current_axis.axis("off") From 0833bc8d61105e6c2dc2d1613039ad12f26939b0 Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Mon, 26 Jun 2023 19:46:55 +0530 Subject: [PATCH 7/9] chore: fix with isort Signed-off-by: Suvaditya Mukherjee --- keras_cv/visualization/plot_image_gallery.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index a05043ca5a..053e78bf99 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -13,9 +13,11 @@ # limitations under the License. import math -import keras_cv + import numpy as np import tensorflow as tf + +import keras_cv from keras_cv import utils from keras_cv.utils import assert_matplotlib_installed From 15b72bc44337e2655775a6de95ffa9d35702efd8 Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Mon, 26 Jun 2023 20:10:52 +0530 Subject: [PATCH 8/9] chore: fix linting on example Signed-off-by: Suvaditya Mukherjee --- examples/visualization/plot_image_gallery.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/visualization/plot_image_gallery.py b/examples/visualization/plot_image_gallery.py index b1cd824fa6..73b35fe93d 100644 --- a/examples/visualization/plot_image_gallery.py +++ b/examples/visualization/plot_image_gallery.py @@ -11,8 +11,9 @@ Plotting images from a TensorFlow dataset is easy with KerasCV. Behold: """ -import tensorflow_datasets as tfds import numpy as np +import tensorflow_datasets as tfds + import keras_cv train_ds = tfds.load( From 701661657a8d20b2b58c0b76d64d790ca61c1918 Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Mon, 26 Jun 2023 20:26:06 +0530 Subject: [PATCH 9/9] chore: linting issues with black and isort solved for example and main file Signed-off-by: Suvaditya Mukherjee --- examples/visualization/plot_image_gallery.py | 3 ++- keras_cv/visualization/plot_image_gallery.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/visualization/plot_image_gallery.py b/examples/visualization/plot_image_gallery.py index 73b35fe93d..17197aca4d 100644 --- a/examples/visualization/plot_image_gallery.py +++ b/examples/visualization/plot_image_gallery.py @@ -1,6 +1,7 @@ """ Title: Plot an image gallery -Author: [lukewood](https://lukewood.xyz), updated by [Suvaditya Mukherjee](https://twitter.com/halcyonrayes) +Author: [lukewood](https://lukewood.xyz), updated by +[Suvaditya Mukherjee](https://twitter.com/halcyonrayes) Date created: 2022/10/16 Last modified: 2022/06/24 Description: Visualize ground truth and predicted bounding boxes for a given diff --git a/keras_cv/visualization/plot_image_gallery.py b/keras_cv/visualization/plot_image_gallery.py index 053e78bf99..5926929b01 100644 --- a/keras_cv/visualization/plot_image_gallery.py +++ b/keras_cv/visualization/plot_image_gallery.py @@ -47,7 +47,8 @@ def unpack_images(inputs): else: if len(images.shape) != 4: raise ValueError( - "`plot_images_gallery()` requires you to batch your `np.array` samples together." + "`plot_images_gallery()` requires you to " + "batch your `np.array` samples together." ) else: num_samples = ( @@ -90,14 +91,17 @@ def plot_image_gallery( ![example gallery](https://i.imgur.com/r0ndse0.png) Args: - images: a Tensor, `tf.data.Dataset` or NumPy array containing images to show in the - gallery. Note: If using a `tf.data.Dataset`, images should be - present in the `FeaturesDict` under the key `image`. + images: a Tensor, `tf.data.Dataset` or NumPy array containing images + to show in the gallery. Note: If using a `tf.data.Dataset`, + images should be present in the `FeaturesDict` under + the key `image`. value_range: value range of the images. Common examples include `(0, 255)` and `(0, 1)`. scale: how large to scale the images in the gallery - rows: (Optional) number of rows in the gallery to show. Required if inputs are unbatched. - cols: (Optional) number of columns in the gallery to show. Required if inputs are unbatched. + rows: (Optional) number of rows in the gallery to show. + Required if inputs are unbatched. + cols: (Optional) number of columns in the gallery to show. + Required if inputs are unbatched. path: (Optional) path to save the resulting gallery to. show: (Optional) whether to show the gallery of images. transparent: (Optional) whether to give the image a transparent