Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras_cv/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from keras_cv.layers.fusedmbconv import FusedMBConvBlock
from keras_cv.layers.mbconv import MBConvBlock
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.layers.object_detection.multi_class_non_max_suppression import (
MultiClassNonMaxSuppression,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"Argmax-based box matching"

from typing import List
from typing import Tuple

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class ArgmaxBoxMatcher(tf.keras.layers.Layer):
class BoxMatcher(tf.keras.layers.Layer):
"""Box matching logic based on argmax of highest value (e.g., IOU).

This class computes matches from a similarity matrix. Each row will be
Expand Down Expand Up @@ -69,8 +66,8 @@ class ArgmaxBoxMatcher(tf.keras.layers.Layer):
Usage:

```python
box_matcher = keras_cv.ops.ArgmaxBoxMatcher([0.3, 0.7], [-1, 0, 1])
iou_metric = keras_cv.bounding_box.compute_iou(anchors, gt_boxes)
box_matcher = keras_cv.layers.BoxMatcher([0.3, 0.7], [-1, 0, 1])
iou_metric = keras_cv.bounding_box.compute_iou(anchors, boxes)
matched_columns, matched_match_values = box_matcher(iou_metric)
cls_mask = tf.less_equal(matched_match_values, 0)
```
Expand Down Expand Up @@ -135,7 +132,7 @@ def _match_when_cols_are_empty():
storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with tf.name_scope("empty_gt_boxes"):
with tf.name_scope("empty_boxes"):
matched_columns = tf.zeros([batch_size, num_rows], dtype=tf.int32)
matched_values = -tf.ones([batch_size, num_rows], dtype=tf.int32)
return matched_columns, matched_values
Expand All @@ -149,7 +146,7 @@ def _match_when_cols_are_non_empty():
storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with tf.name_scope("non_empty_gt_boxes"):
with tf.name_scope("non_empty_boxes"):
matched_columns = tf.argmax(
similarity_matrix, axis=-1, output_type=tf.int32
)
Expand Down Expand Up @@ -207,11 +204,11 @@ def _match_when_cols_are_non_empty():

return matched_columns, matched_values

num_gt_boxes = (
num_boxes = (
similarity_matrix.shape.as_list()[-1] or tf.shape(similarity_matrix)[-1]
)
matched_columns, matched_values = tf.cond(
pred=tf.greater(num_gt_boxes, 0),
pred=tf.greater(num_boxes, 0),
true_fn=_match_when_cols_are_non_empty,
false_fn=_match_when_cols_are_empty,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@

import tensorflow as tf

from keras_cv.ops.box_matcher import ArgmaxBoxMatcher
from keras_cv.layers.object_detection.box_matcher import BoxMatcher


class ArgmaxBoxMatcherTest(tf.test.TestCase):
class BoxMatcherTest(tf.test.TestCase):
def test_box_matcher_invalid_length(self):
fg_threshold = 0.5
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

with self.assertRaisesRegex(ValueError, "must be len"):
_ = ArgmaxBoxMatcher(
_ = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1],
)
Expand All @@ -35,7 +35,7 @@ def test_box_matcher_unsorted_thresholds(self):
bg_thresh_lo = 0.0

with self.assertRaisesRegex(ValueError, "must be sorted"):
_ = ArgmaxBoxMatcher(
_ = BoxMatcher(
thresholds=[bg_thresh_hi, bg_thresh_lo, fg_threshold],
match_values=[-3, -2, -1, 1],
)
Expand All @@ -47,7 +47,7 @@ def test_box_matcher_unbatched(self):
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

matcher = ArgmaxBoxMatcher(
matcher = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
)
Expand All @@ -67,7 +67,7 @@ def test_box_matcher_batched(self):
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

matcher = ArgmaxBoxMatcher(
matcher = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
)
Expand All @@ -90,7 +90,7 @@ def test_box_matcher_force_match(self):
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

matcher = ArgmaxBoxMatcher(
matcher = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
force_match_for_each_col=True,
Expand All @@ -113,7 +113,7 @@ def test_box_matcher_empty_gt_boxes(self):
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

matcher = ArgmaxBoxMatcher(
matcher = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
)
Expand Down
6 changes: 3 additions & 3 deletions keras_cv/layers/object_detection/retina_net_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from tensorflow.keras import layers

from keras_cv import bounding_box
from keras_cv.ops import box_matcher
from keras_cv.ops import target_gather
from keras_cv.layers.object_detection import box_matcher
from keras_cv.utils import target_gather


class RetinaNetLabelEncoder(layers.Layer):
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
)
self.positive_threshold = positive_threshold
self.negative_threshold = negative_threshold
self.box_matcher = box_matcher.ArgmaxBoxMatcher(
self.box_matcher = box_matcher.BoxMatcher(
thresholds=[negative_threshold, positive_threshold],
match_values=[-1, -2, 1],
force_match_for_each_col=False,
Expand Down
12 changes: 6 additions & 6 deletions keras_cv/layers/object_detection/roi_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from keras_cv import bounding_box
from keras_cv.bounding_box import iou
from keras_cv.ops import box_matcher
from keras_cv.ops import sampling
from keras_cv.ops import target_gather
from keras_cv.layers.object_detection import box_matcher
from keras_cv.layers.object_detection import sampling
from keras_cv.layers.object_detection import target_gather


@tf.keras.utils.register_keras_serializable(package="keras_cv")
Expand All @@ -44,7 +44,7 @@ class _ROISampler(tf.keras.layers.Layer):
bounding_box_format: The format of bounding boxes to generate. Refer
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
for more details on supported bounding box formats.
roi_matcher: a `ArgmaxBoxMatcher` object that matches proposals
roi_matcher: a `BoxMatcher` object that matches proposals
with ground truth boxes. the positive match must be 1 and negative match must be -1.
Such assumption is not being validated here.
positive_fraction: the positive ratio w.r.t `num_sampled_rois`. Defaults to 0.25.
Expand All @@ -59,7 +59,7 @@ class _ROISampler(tf.keras.layers.Layer):
def __init__(
self,
bounding_box_format: str,
roi_matcher: box_matcher.ArgmaxBoxMatcher,
roi_matcher: box_matcher.BoxMatcher,
positive_fraction: float = 0.25,
background_class: int = 0,
num_sampled_rois: int = 256,
Expand Down Expand Up @@ -205,5 +205,5 @@ def get_config(self):
@classmethod
def from_config(cls, config, custom_objects=None):
roi_matcher_config = config.pop("roi_matcher")
roi_matcher = box_matcher.ArgmaxBoxMatcher(**roi_matcher_config)
roi_matcher = box_matcher.BoxMatcher(**roi_matcher_config)
return cls(roi_matcher=roi_matcher, **config)
16 changes: 8 additions & 8 deletions keras_cv/layers/object_detection/roi_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

import tensorflow as tf

from keras_cv.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.layers.object_detection.roi_sampler import _ROISampler
from keras_cv.ops.box_matcher import ArgmaxBoxMatcher


class ROISamplerTest(tf.test.TestCase):
def test_roi_sampler(self):
box_matcher = ArgmaxBoxMatcher(thresholds=[0.3], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.3], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_roi_sampler(self):
)

def test_roi_sampler_small_threshold(self):
box_matcher = ArgmaxBoxMatcher(thresholds=[0.1], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.1], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_roi_sampler_small_threshold(self):

def test_roi_sampler_large_threshold(self):
# the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_roi_sampler_large_threshold(self):

def test_roi_sampler_large_threshold_custom_bg_class(self):
# the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_roi_sampler_large_threshold_custom_bg_class(self):

def test_roi_sampler_large_threshold_append_gt_boxes(self):
# the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_roi_sampler_large_threshold_append_gt_boxes(self):
self.assertAllGreaterEqual(tf.reduce_min(sampled_gt_classes), 0)

def test_roi_sampler_large_num_sampled_rois(self):
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand All @@ -227,7 +227,7 @@ def test_roi_sampler_large_num_sampled_rois(self):
_, _, _ = roi_sampler(rois, gt_boxes, gt_classes)

def test_serialization(self):
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down
8 changes: 4 additions & 4 deletions keras_cv/layers/object_detection/rpn_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

from keras_cv import bounding_box
from keras_cv.bounding_box import iou
from keras_cv.ops import box_matcher
from keras_cv.ops import sampling
from keras_cv.ops import target_gather
from keras_cv.layers.object_detection import box_matcher
from keras_cv.layers.object_detection import sampling
from keras_cv.utils import target_gather


@tf.keras.utils.register_keras_serializable(package="keras_cv")
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
self.negative_threshold = negative_threshold
self.samples_per_image = samples_per_image
self.positive_fraction = positive_fraction
self.box_matcher = box_matcher.ArgmaxBoxMatcher(
self.box_matcher = box_matcher.BoxMatcher(
thresholds=[negative_threshold, positive_threshold],
match_values=[-1, -2, 1],
force_match_for_each_col=False,
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import tensorflow as tf

from keras_cv.ops.sampling import balanced_sample
from keras_cv.layers.object_detection.sampling import balanced_sample


class BalancedSamplingTest(tf.test.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import tensorflow as tf


def _target_gather(
targets: tf.Tensor,
indices: tf.Tensor,
mask: Optional[tf.Tensor] = None,
mask_val: Optional[float] = 0.0,
mask=None,
mask_val=0.0,
):
"""A utility function wrapping tf.gather, which deals with:
1) both batched and unbatched `targets`
Expand Down
Loading