diff --git a/keras_cv/layers/__init__.py b/keras_cv/layers/__init__.py index 8a9df9f125..ec17913dc4 100644 --- a/keras_cv/layers/__init__.py +++ b/keras_cv/layers/__init__.py @@ -21,6 +21,7 @@ from keras_cv.layers.fusedmbconv import FusedMBConvBlock from keras_cv.layers.mbconv import MBConvBlock from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator +from keras_cv.layers.object_detection.box_matcher import BoxMatcher from keras_cv.layers.object_detection.multi_class_non_max_suppression import ( MultiClassNonMaxSuppression, ) diff --git a/keras_cv/ops/box_matcher.py b/keras_cv/layers/object_detection/box_matcher.py similarity index 96% rename from keras_cv/ops/box_matcher.py rename to keras_cv/layers/object_detection/box_matcher.py index bcf52afcdf..b4d54bc95f 100644 --- a/keras_cv/ops/box_matcher.py +++ b/keras_cv/layers/object_detection/box_matcher.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"Argmax-based box matching" - from typing import List from typing import Tuple @@ -21,7 +18,7 @@ @tf.keras.utils.register_keras_serializable(package="keras_cv") -class ArgmaxBoxMatcher(tf.keras.layers.Layer): +class BoxMatcher(tf.keras.layers.Layer): """Box matching logic based on argmax of highest value (e.g., IOU). This class computes matches from a similarity matrix. Each row will be @@ -69,8 +66,8 @@ class ArgmaxBoxMatcher(tf.keras.layers.Layer): Usage: ```python - box_matcher = keras_cv.ops.ArgmaxBoxMatcher([0.3, 0.7], [-1, 0, 1]) - iou_metric = keras_cv.bounding_box.compute_iou(anchors, gt_boxes) + box_matcher = keras_cv.layers.BoxMatcher([0.3, 0.7], [-1, 0, 1]) + iou_metric = keras_cv.bounding_box.compute_iou(anchors, boxes) matched_columns, matched_match_values = box_matcher(iou_metric) cls_mask = tf.less_equal(matched_match_values, 0) ``` @@ -135,7 +132,7 @@ def _match_when_cols_are_empty(): storing the match type indicator (e.g. positive or negative or ignored match). """ - with tf.name_scope("empty_gt_boxes"): + with tf.name_scope("empty_boxes"): matched_columns = tf.zeros([batch_size, num_rows], dtype=tf.int32) matched_values = -tf.ones([batch_size, num_rows], dtype=tf.int32) return matched_columns, matched_values @@ -149,7 +146,7 @@ def _match_when_cols_are_non_empty(): storing the match type indicator (e.g. positive or negative or ignored match). """ - with tf.name_scope("non_empty_gt_boxes"): + with tf.name_scope("non_empty_boxes"): matched_columns = tf.argmax( similarity_matrix, axis=-1, output_type=tf.int32 ) @@ -207,11 +204,11 @@ def _match_when_cols_are_non_empty(): return matched_columns, matched_values - num_gt_boxes = ( + num_boxes = ( similarity_matrix.shape.as_list()[-1] or tf.shape(similarity_matrix)[-1] ) matched_columns, matched_values = tf.cond( - pred=tf.greater(num_gt_boxes, 0), + pred=tf.greater(num_boxes, 0), true_fn=_match_when_cols_are_non_empty, false_fn=_match_when_cols_are_empty, ) diff --git a/keras_cv/ops/box_matcher_test.py b/keras_cv/layers/object_detection/box_matcher_test.py similarity index 93% rename from keras_cv/ops/box_matcher_test.py rename to keras_cv/layers/object_detection/box_matcher_test.py index 686ec1004b..7144f60059 100644 --- a/keras_cv/ops/box_matcher_test.py +++ b/keras_cv/layers/object_detection/box_matcher_test.py @@ -14,17 +14,17 @@ import tensorflow as tf -from keras_cv.ops.box_matcher import ArgmaxBoxMatcher +from keras_cv.layers.object_detection.box_matcher import BoxMatcher -class ArgmaxBoxMatcherTest(tf.test.TestCase): +class BoxMatcherTest(tf.test.TestCase): def test_box_matcher_invalid_length(self): fg_threshold = 0.5 bg_thresh_hi = 0.2 bg_thresh_lo = 0.0 with self.assertRaisesRegex(ValueError, "must be len"): - _ = ArgmaxBoxMatcher( + _ = BoxMatcher( thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], match_values=[-3, -2, -1], ) @@ -35,7 +35,7 @@ def test_box_matcher_unsorted_thresholds(self): bg_thresh_lo = 0.0 with self.assertRaisesRegex(ValueError, "must be sorted"): - _ = ArgmaxBoxMatcher( + _ = BoxMatcher( thresholds=[bg_thresh_hi, bg_thresh_lo, fg_threshold], match_values=[-3, -2, -1, 1], ) @@ -47,7 +47,7 @@ def test_box_matcher_unbatched(self): bg_thresh_hi = 0.2 bg_thresh_lo = 0.0 - matcher = ArgmaxBoxMatcher( + matcher = BoxMatcher( thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], match_values=[-3, -2, -1, 1], ) @@ -67,7 +67,7 @@ def test_box_matcher_batched(self): bg_thresh_hi = 0.2 bg_thresh_lo = 0.0 - matcher = ArgmaxBoxMatcher( + matcher = BoxMatcher( thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], match_values=[-3, -2, -1, 1], ) @@ -90,7 +90,7 @@ def test_box_matcher_force_match(self): bg_thresh_hi = 0.2 bg_thresh_lo = 0.0 - matcher = ArgmaxBoxMatcher( + matcher = BoxMatcher( thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], match_values=[-3, -2, -1, 1], force_match_for_each_col=True, @@ -113,7 +113,7 @@ def test_box_matcher_empty_gt_boxes(self): bg_thresh_hi = 0.2 bg_thresh_lo = 0.0 - matcher = ArgmaxBoxMatcher( + matcher = BoxMatcher( thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold], match_values=[-3, -2, -1, 1], ) diff --git a/keras_cv/layers/object_detection/retina_net_label_encoder.py b/keras_cv/layers/object_detection/retina_net_label_encoder.py index 717e2d3b33..769731de6c 100644 --- a/keras_cv/layers/object_detection/retina_net_label_encoder.py +++ b/keras_cv/layers/object_detection/retina_net_label_encoder.py @@ -16,8 +16,8 @@ from tensorflow.keras import layers from keras_cv import bounding_box -from keras_cv.ops import box_matcher -from keras_cv.ops import target_gather +from keras_cv.layers.object_detection import box_matcher +from keras_cv.utils import target_gather class RetinaNetLabelEncoder(layers.Layer): @@ -66,7 +66,7 @@ def __init__( ) self.positive_threshold = positive_threshold self.negative_threshold = negative_threshold - self.box_matcher = box_matcher.ArgmaxBoxMatcher( + self.box_matcher = box_matcher.BoxMatcher( thresholds=[negative_threshold, positive_threshold], match_values=[-1, -2, 1], force_match_for_each_col=False, diff --git a/keras_cv/layers/object_detection/roi_sampler.py b/keras_cv/layers/object_detection/roi_sampler.py index c3f77a5909..ba8dafdc20 100644 --- a/keras_cv/layers/object_detection/roi_sampler.py +++ b/keras_cv/layers/object_detection/roi_sampler.py @@ -16,9 +16,9 @@ from keras_cv import bounding_box from keras_cv.bounding_box import iou -from keras_cv.ops import box_matcher -from keras_cv.ops import sampling -from keras_cv.ops import target_gather +from keras_cv.layers.object_detection import box_matcher +from keras_cv.layers.object_detection import sampling +from keras_cv.layers.object_detection import target_gather @tf.keras.utils.register_keras_serializable(package="keras_cv") @@ -44,7 +44,7 @@ class _ROISampler(tf.keras.layers.Layer): bounding_box_format: The format of bounding boxes to generate. Refer [to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/) for more details on supported bounding box formats. - roi_matcher: a `ArgmaxBoxMatcher` object that matches proposals + roi_matcher: a `BoxMatcher` object that matches proposals with ground truth boxes. the positive match must be 1 and negative match must be -1. Such assumption is not being validated here. positive_fraction: the positive ratio w.r.t `num_sampled_rois`. Defaults to 0.25. @@ -59,7 +59,7 @@ class _ROISampler(tf.keras.layers.Layer): def __init__( self, bounding_box_format: str, - roi_matcher: box_matcher.ArgmaxBoxMatcher, + roi_matcher: box_matcher.BoxMatcher, positive_fraction: float = 0.25, background_class: int = 0, num_sampled_rois: int = 256, @@ -205,5 +205,5 @@ def get_config(self): @classmethod def from_config(cls, config, custom_objects=None): roi_matcher_config = config.pop("roi_matcher") - roi_matcher = box_matcher.ArgmaxBoxMatcher(**roi_matcher_config) + roi_matcher = box_matcher.BoxMatcher(**roi_matcher_config) return cls(roi_matcher=roi_matcher, **config) diff --git a/keras_cv/layers/object_detection/roi_sampler_test.py b/keras_cv/layers/object_detection/roi_sampler_test.py index 4b3e752326..45e04ae390 100644 --- a/keras_cv/layers/object_detection/roi_sampler_test.py +++ b/keras_cv/layers/object_detection/roi_sampler_test.py @@ -14,13 +14,13 @@ import tensorflow as tf +from keras_cv.layers.object_detection.box_matcher import BoxMatcher from keras_cv.layers.object_detection.roi_sampler import _ROISampler -from keras_cv.ops.box_matcher import ArgmaxBoxMatcher class ROISamplerTest(tf.test.TestCase): def test_roi_sampler(self): - box_matcher = ArgmaxBoxMatcher(thresholds=[0.3], match_values=[-1, 1]) + box_matcher = BoxMatcher(thresholds=[0.3], match_values=[-1, 1]) roi_sampler = _ROISampler( bounding_box_format="xyxy", roi_matcher=box_matcher, @@ -57,7 +57,7 @@ def test_roi_sampler(self): ) def test_roi_sampler_small_threshold(self): - box_matcher = ArgmaxBoxMatcher(thresholds=[0.1], match_values=[-1, 1]) + box_matcher = BoxMatcher(thresholds=[0.1], match_values=[-1, 1]) roi_sampler = _ROISampler( bounding_box_format="xyxy", roi_matcher=box_matcher, @@ -106,7 +106,7 @@ def test_roi_sampler_small_threshold(self): def test_roi_sampler_large_threshold(self): # the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it - box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) roi_sampler = _ROISampler( bounding_box_format="xyxy", roi_matcher=box_matcher, @@ -139,7 +139,7 @@ def test_roi_sampler_large_threshold(self): def test_roi_sampler_large_threshold_custom_bg_class(self): # the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it - box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) roi_sampler = _ROISampler( bounding_box_format="xyxy", roi_matcher=box_matcher, @@ -173,7 +173,7 @@ def test_roi_sampler_large_threshold_custom_bg_class(self): def test_roi_sampler_large_threshold_append_gt_boxes(self): # the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it - box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) roi_sampler = _ROISampler( bounding_box_format="xyxy", roi_matcher=box_matcher, @@ -204,7 +204,7 @@ def test_roi_sampler_large_threshold_append_gt_boxes(self): self.assertAllGreaterEqual(tf.reduce_min(sampled_gt_classes), 0) def test_roi_sampler_large_num_sampled_rois(self): - box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) roi_sampler = _ROISampler( bounding_box_format="xyxy", roi_matcher=box_matcher, @@ -227,7 +227,7 @@ def test_roi_sampler_large_num_sampled_rois(self): _, _, _ = roi_sampler(rois, gt_boxes, gt_classes) def test_serialization(self): - box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1]) + box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1]) roi_sampler = _ROISampler( bounding_box_format="xyxy", roi_matcher=box_matcher, diff --git a/keras_cv/layers/object_detection/rpn_label_encoder.py b/keras_cv/layers/object_detection/rpn_label_encoder.py index e64e68ff13..904ba14efb 100644 --- a/keras_cv/layers/object_detection/rpn_label_encoder.py +++ b/keras_cv/layers/object_detection/rpn_label_encoder.py @@ -18,9 +18,9 @@ from keras_cv import bounding_box from keras_cv.bounding_box import iou -from keras_cv.ops import box_matcher -from keras_cv.ops import sampling -from keras_cv.ops import target_gather +from keras_cv.layers.object_detection import box_matcher +from keras_cv.layers.object_detection import sampling +from keras_cv.utils import target_gather @tf.keras.utils.register_keras_serializable(package="keras_cv") @@ -74,7 +74,7 @@ def __init__( self.negative_threshold = negative_threshold self.samples_per_image = samples_per_image self.positive_fraction = positive_fraction - self.box_matcher = box_matcher.ArgmaxBoxMatcher( + self.box_matcher = box_matcher.BoxMatcher( thresholds=[negative_threshold, positive_threshold], match_values=[-1, -2, 1], force_match_for_each_col=False, diff --git a/keras_cv/ops/sampling.py b/keras_cv/layers/object_detection/sampling.py similarity index 100% rename from keras_cv/ops/sampling.py rename to keras_cv/layers/object_detection/sampling.py diff --git a/keras_cv/ops/sampling_test.py b/keras_cv/layers/object_detection/sampling_test.py similarity index 98% rename from keras_cv/ops/sampling_test.py rename to keras_cv/layers/object_detection/sampling_test.py index 42e9eddbd3..7f22ead5ba 100644 --- a/keras_cv/ops/sampling_test.py +++ b/keras_cv/layers/object_detection/sampling_test.py @@ -14,7 +14,7 @@ import tensorflow as tf -from keras_cv.ops.sampling import balanced_sample +from keras_cv.layers.object_detection.sampling import balanced_sample class BalancedSamplingTest(tf.test.TestCase): diff --git a/keras_cv/ops/target_gather.py b/keras_cv/layers/object_detection/target_gather.py similarity index 97% rename from keras_cv/ops/target_gather.py rename to keras_cv/layers/object_detection/target_gather.py index 9044199b1f..772faeeae0 100644 --- a/keras_cv/ops/target_gather.py +++ b/keras_cv/layers/object_detection/target_gather.py @@ -12,16 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional - import tensorflow as tf def _target_gather( targets: tf.Tensor, indices: tf.Tensor, - mask: Optional[tf.Tensor] = None, - mask_val: Optional[float] = 0.0, + mask=None, + mask_val=0.0, ): """A utility function wrapping tf.gather, which deals with: 1) both batched and unbatched `targets` diff --git a/keras_cv/layers/object_detection/target_gather_test.py b/keras_cv/layers/object_detection/target_gather_test.py new file mode 100644 index 0000000000..1e5b3b5938 --- /dev/null +++ b/keras_cv/layers/object_detection/target_gather_test.py @@ -0,0 +1,117 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + +from keras_cv.layers.object_detection.target_gather import _target_gather + + +class TargetGatherTest(tf.test.TestCase): + def test_target_gather_boxes_batched(self): + target_boxes = tf.constant( + [[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]] + ) + target_boxes = target_boxes[tf.newaxis, ...] + indices = tf.constant([[0, 2]], dtype=tf.int32) + expected_boxes = tf.constant([[0, 0, 5, 5], [5, 0, 10, 5]]) + expected_boxes = expected_boxes[tf.newaxis, ...] + res = _target_gather(target_boxes, indices) + self.assertAllClose(expected_boxes, res) + + def test_target_gather_boxes_unbatched(self): + target_boxes = tf.constant( + [[0, 0, 5, 5], [0, 5, 5, 10], [5, 0, 10, 5], [5, 5, 10, 10]] + ) + indices = tf.constant([0, 2], dtype=tf.int32) + expected_boxes = tf.constant([[0, 0, 5, 5], [5, 0, 10, 5]]) + res = _target_gather(target_boxes, indices) + self.assertAllClose(expected_boxes, res) + + def test_target_gather_classes_batched(self): + target_classes = tf.constant([[1, 2, 3, 4]]) + target_classes = target_classes[..., tf.newaxis] + indices = tf.constant([[0, 2]], dtype=tf.int32) + expected_classes = tf.constant([[1, 3]]) + expected_classes = expected_classes[..., tf.newaxis] + res = _target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_unbatched(self): + target_classes = tf.constant([1, 2, 3, 4]) + target_classes = target_classes[..., tf.newaxis] + indices = tf.constant([0, 2], dtype=tf.int32) + expected_classes = tf.constant([1, 3]) + expected_classes = expected_classes[..., tf.newaxis] + res = _target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_batched_with_mask(self): + target_classes = tf.constant([[1, 2, 3, 4]]) + target_classes = target_classes[..., tf.newaxis] + indices = tf.constant([[0, 2]], dtype=tf.int32) + masks = tf.constant(([[False, True]])) + masks = masks[..., tf.newaxis] + # the second element is masked + expected_classes = tf.constant([[1, 0]]) + expected_classes = expected_classes[..., tf.newaxis] + res = _target_gather(target_classes, indices, masks) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_batched_with_mask_val(self): + target_classes = tf.constant([[1, 2, 3, 4]]) + target_classes = target_classes[..., tf.newaxis] + indices = tf.constant([[0, 2]], dtype=tf.int32) + masks = tf.constant(([[False, True]])) + masks = masks[..., tf.newaxis] + # the second element is masked + expected_classes = tf.constant([[1, -1]]) + expected_classes = expected_classes[..., tf.newaxis] + res = _target_gather(target_classes, indices, masks, -1) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_unbatched_with_mask(self): + target_classes = tf.constant([1, 2, 3, 4]) + target_classes = target_classes[..., tf.newaxis] + indices = tf.constant([0, 2], dtype=tf.int32) + masks = tf.constant([False, True]) + masks = masks[..., tf.newaxis] + expected_classes = tf.constant([1, 0]) + expected_classes = expected_classes[..., tf.newaxis] + res = _target_gather(target_classes, indices, masks) + self.assertAllClose(expected_classes, res) + + def test_target_gather_with_empty_targets(self): + target_classes = tf.constant([]) + target_classes = target_classes[..., tf.newaxis] + indices = tf.constant([0, 2], dtype=tf.int32) + # return all 0s since input is empty + expected_classes = tf.constant([0, 0]) + expected_classes = expected_classes[..., tf.newaxis] + res = _target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_classes_multi_batch(self): + target_classes = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]]) + target_classes = target_classes[..., tf.newaxis] + indices = tf.constant([[0, 2], [1, 3]], dtype=tf.int32) + expected_classes = tf.constant([[1, 3], [6, 8]]) + expected_classes = expected_classes[..., tf.newaxis] + res = _target_gather(target_classes, indices) + self.assertAllClose(expected_classes, res) + + def test_target_gather_invalid_rank(self): + targets = tf.random.normal([32, 2, 2, 2]) + indices = tf.constant([0, 1], dtype=tf.int32) + with self.assertRaisesRegex(ValueError, "larger than 3"): + _ = _target_gather(targets, indices) diff --git a/keras_cv/models/object_detection/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn.py index bf7ee01c62..dfa0c31cea 100644 --- a/keras_cv/models/object_detection/faster_rcnn.py +++ b/keras_cv/models/object_detection/faster_rcnn.py @@ -21,12 +21,12 @@ from keras_cv.bounding_box.converters import _decode_deltas_to_boxes from keras_cv.bounding_box.utils import _clip_boxes from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator +from keras_cv.layers.object_detection.box_matcher import BoxMatcher from keras_cv.layers.object_detection.roi_align import _ROIAligner from keras_cv.layers.object_detection.roi_generator import ROIGenerator from keras_cv.layers.object_detection.roi_sampler import _ROISampler from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder from keras_cv.models.object_detection import predict_utils -from keras_cv.ops.box_matcher import ArgmaxBoxMatcher BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] @@ -298,9 +298,7 @@ def __init__( nms_score_threshold_train=float("-inf"), nms_score_threshold_test=float("-inf"), ) - self.box_matcher = ArgmaxBoxMatcher( - thresholds=[0.0, 0.5], match_values=[-2, -1, 1] - ) + self.box_matcher = BoxMatcher(thresholds=[0.0, 0.5], match_values=[-2, -1, 1]) self.roi_sampler = _ROISampler( bounding_box_format="yxyx", roi_matcher=self.box_matcher, diff --git a/keras_cv/utils/target_gather.py b/keras_cv/utils/target_gather.py new file mode 100644 index 0000000000..772faeeae0 --- /dev/null +++ b/keras_cv/utils/target_gather.py @@ -0,0 +1,122 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + + +def _target_gather( + targets: tf.Tensor, + indices: tf.Tensor, + mask=None, + mask_val=0.0, +): + """A utility function wrapping tf.gather, which deals with: + 1) both batched and unbatched `targets` + 2) when unbatched `targets` have empty rows, the result will be filled + with `mask_val` + 3) target masking. + + Args: + targets: [N, ...] or [batch_size, N, ...] Tensor representing targets such + as boxes, keypoints, etc. + indices: [M] or [batch_size, M] int32 Tensor representing indices within + `targets` to gather. + mask: optional [M, ...] or [batch_size, M, ...] boolean + Tensor representing the masking for each target. `True` means the corresponding + entity should be masked to `mask_val`, `False` means the corresponding entity + should be the target value. + mask_val: optinal float representing the masking value if `mask` is True on + the entity. + + Returns: + targets: [M, ...] or [batch_size, M, ...] Tensor representing selected targets. + + Raise: + ValueError: If `targets` is higher than rank 3. + """ + targets_shape = targets.get_shape().as_list() + if len(targets_shape) > 3: + raise ValueError( + "`target_gather` does not support `targets` with rank " + "larger than 3, got {}".format(len(targets.shape)) + ) + + def _gather_unbatched(labels, match_indices, mask, mask_val): + """Gather based on unbatched labels and boxes.""" + num_gt_boxes = tf.shape(labels)[0] + + def _assign_when_rows_empty(): + if len(labels.shape) > 1: + mask_shape = [match_indices.shape[0], labels.shape[-1]] + else: + mask_shape = [match_indices.shape[0]] + return tf.cast(mask_val, labels.dtype) * tf.ones( + mask_shape, dtype=labels.dtype + ) + + def _assign_when_rows_not_empty(): + targets = tf.gather(labels, match_indices) + if mask is None: + return targets + else: + masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like( + mask, dtype=labels.dtype + ) + return tf.where(mask, masked_targets, targets) + + return tf.cond( + tf.greater(num_gt_boxes, 0), + _assign_when_rows_not_empty, + _assign_when_rows_empty, + ) + + def _gather_batched(labels, match_indices, mask, mask_val): + """Gather based on batched labels.""" + batch_size = labels.shape[0] + if batch_size == 1: + if mask is not None: + result = _gather_unbatched( + tf.squeeze(labels, axis=0), + tf.squeeze(match_indices, axis=0), + tf.squeeze(mask, axis=0), + mask_val, + ) + else: + result = _gather_unbatched( + tf.squeeze(labels, axis=0), + tf.squeeze(match_indices, axis=0), + None, + mask_val, + ) + return tf.expand_dims(result, axis=0) + else: + indices_shape = tf.shape(match_indices) + indices_dtype = match_indices.dtype + batch_indices = tf.expand_dims( + tf.range(indices_shape[0], dtype=indices_dtype), axis=-1 + ) * tf.ones([1, indices_shape[-1]], dtype=indices_dtype) + gather_nd_indices = tf.stack([batch_indices, match_indices], axis=-1) + targets = tf.gather_nd(labels, gather_nd_indices) + if mask is None: + return targets + else: + masked_targets = tf.cast(mask_val, labels.dtype) * tf.ones_like( + mask, dtype=labels.dtype + ) + return tf.where(mask, masked_targets, targets) + + if len(targets_shape) <= 2: + return _gather_unbatched(targets, indices, mask, mask_val) + elif len(targets_shape) == 3: + return _gather_batched(targets, indices, mask, mask_val) diff --git a/keras_cv/ops/target_gather_test.py b/keras_cv/utils/target_gather_test.py similarity index 98% rename from keras_cv/ops/target_gather_test.py rename to keras_cv/utils/target_gather_test.py index 95d7f44d08..cdfb9d188a 100644 --- a/keras_cv/ops/target_gather_test.py +++ b/keras_cv/utils/target_gather_test.py @@ -14,7 +14,7 @@ import tensorflow as tf -from keras_cv.ops.target_gather import _target_gather +from keras_cv.utils.target_gather import _target_gather class TargetGatherTest(tf.test.TestCase):