Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras_cv/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from keras_cv.layers.fusedmbconv import FusedMBConvBlock
from keras_cv.layers.mbconv import MBConvBlock
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.layers.object_detection.multi_class_non_max_suppression import (
MultiClassNonMaxSuppression,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"Argmax-based box matching"

from typing import List
from typing import Tuple

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class ArgmaxBoxMatcher(tf.keras.layers.Layer):
class BoxMatcher(tf.keras.layers.Layer):
"""Box matching logic based on argmax of highest value (e.g., IOU).

This class computes matches from a similarity matrix. Each row will be
Expand Down Expand Up @@ -69,8 +66,8 @@ class ArgmaxBoxMatcher(tf.keras.layers.Layer):
Usage:

```python
box_matcher = keras_cv.ops.ArgmaxBoxMatcher([0.3, 0.7], [-1, 0, 1])
iou_metric = keras_cv.bounding_box.compute_iou(anchors, gt_boxes)
box_matcher = keras_cv.layers.BoxMatcher([0.3, 0.7], [-1, 0, 1])
iou_metric = keras_cv.bounding_box.compute_iou(anchors, boxes)
matched_columns, matched_match_values = box_matcher(iou_metric)
cls_mask = tf.less_equal(matched_match_values, 0)
```
Expand Down Expand Up @@ -135,7 +132,7 @@ def _match_when_cols_are_empty():
storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with tf.name_scope("empty_gt_boxes"):
with tf.name_scope("empty_boxes"):
matched_columns = tf.zeros([batch_size, num_rows], dtype=tf.int32)
matched_values = -tf.ones([batch_size, num_rows], dtype=tf.int32)
return matched_columns, matched_values
Expand All @@ -149,7 +146,7 @@ def _match_when_cols_are_non_empty():
storing the match type indicator (e.g. positive or negative
or ignored match).
"""
with tf.name_scope("non_empty_gt_boxes"):
with tf.name_scope("non_empty_boxes"):
matched_columns = tf.argmax(
similarity_matrix, axis=-1, output_type=tf.int32
)
Expand Down Expand Up @@ -207,11 +204,11 @@ def _match_when_cols_are_non_empty():

return matched_columns, matched_values

num_gt_boxes = (
num_boxes = (
similarity_matrix.shape.as_list()[-1] or tf.shape(similarity_matrix)[-1]
)
matched_columns, matched_values = tf.cond(
pred=tf.greater(num_gt_boxes, 0),
pred=tf.greater(num_boxes, 0),
true_fn=_match_when_cols_are_non_empty,
false_fn=_match_when_cols_are_empty,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@

import tensorflow as tf

from keras_cv.ops.box_matcher import ArgmaxBoxMatcher
from keras_cv.layers.object_detection.box_matcher import BoxMatcher


class ArgmaxBoxMatcherTest(tf.test.TestCase):
class BoxMatcherTest(tf.test.TestCase):
def test_box_matcher_invalid_length(self):
fg_threshold = 0.5
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

with self.assertRaisesRegex(ValueError, "must be len"):
_ = ArgmaxBoxMatcher(
_ = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1],
)
Expand All @@ -35,7 +35,7 @@ def test_box_matcher_unsorted_thresholds(self):
bg_thresh_lo = 0.0

with self.assertRaisesRegex(ValueError, "must be sorted"):
_ = ArgmaxBoxMatcher(
_ = BoxMatcher(
thresholds=[bg_thresh_hi, bg_thresh_lo, fg_threshold],
match_values=[-3, -2, -1, 1],
)
Expand All @@ -47,7 +47,7 @@ def test_box_matcher_unbatched(self):
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

matcher = ArgmaxBoxMatcher(
matcher = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
)
Expand All @@ -67,7 +67,7 @@ def test_box_matcher_batched(self):
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

matcher = ArgmaxBoxMatcher(
matcher = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
)
Expand All @@ -90,7 +90,7 @@ def test_box_matcher_force_match(self):
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

matcher = ArgmaxBoxMatcher(
matcher = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
force_match_for_each_col=True,
Expand All @@ -113,7 +113,7 @@ def test_box_matcher_empty_gt_boxes(self):
bg_thresh_hi = 0.2
bg_thresh_lo = 0.0

matcher = ArgmaxBoxMatcher(
matcher = BoxMatcher(
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
match_values=[-3, -2, -1, 1],
)
Expand Down
6 changes: 3 additions & 3 deletions keras_cv/layers/object_detection/retina_net_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from tensorflow.keras import layers

from keras_cv import bounding_box
from keras_cv.ops import box_matcher
from keras_cv.ops import target_gather
from keras_cv.layers.object_detection import box_matcher
from keras_cv.layers.object_detection import target_gather


class RetinaNetLabelEncoder(layers.Layer):
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
)
self.positive_threshold = positive_threshold
self.negative_threshold = negative_threshold
self.box_matcher = box_matcher.ArgmaxBoxMatcher(
self.box_matcher = box_matcher.BoxMatcher(
thresholds=[negative_threshold, positive_threshold],
match_values=[-1, -2, 1],
force_match_for_each_col=False,
Expand Down
12 changes: 6 additions & 6 deletions keras_cv/layers/object_detection/roi_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from keras_cv import bounding_box
from keras_cv.bounding_box import iou
from keras_cv.ops import box_matcher
from keras_cv.ops import sampling
from keras_cv.ops import target_gather
from keras_cv.layers.object_detection import box_matcher
from keras_cv.layers.object_detection import sampling
from keras_cv.layers.object_detection import target_gather


@tf.keras.utils.register_keras_serializable(package="keras_cv")
Expand All @@ -44,7 +44,7 @@ class _ROISampler(tf.keras.layers.Layer):
bounding_box_format: The format of bounding boxes to generate. Refer
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
for more details on supported bounding box formats.
roi_matcher: a `ArgmaxBoxMatcher` object that matches proposals
roi_matcher: a `BoxMatcher` object that matches proposals
with ground truth boxes. the positive match must be 1 and negative match must be -1.
Such assumption is not being validated here.
positive_fraction: the positive ratio w.r.t `num_sampled_rois`. Defaults to 0.25.
Expand All @@ -59,7 +59,7 @@ class _ROISampler(tf.keras.layers.Layer):
def __init__(
self,
bounding_box_format: str,
roi_matcher: box_matcher.ArgmaxBoxMatcher,
roi_matcher: box_matcher.BoxMatcher,
positive_fraction: float = 0.25,
background_class: int = 0,
num_sampled_rois: int = 256,
Expand Down Expand Up @@ -205,5 +205,5 @@ def get_config(self):
@classmethod
def from_config(cls, config, custom_objects=None):
roi_matcher_config = config.pop("roi_matcher")
roi_matcher = box_matcher.ArgmaxBoxMatcher(**roi_matcher_config)
roi_matcher = box_matcher.BoxMatcher(**roi_matcher_config)
return cls(roi_matcher=roi_matcher, **config)
16 changes: 8 additions & 8 deletions keras_cv/layers/object_detection/roi_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

import tensorflow as tf

from keras_cv.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.layers.object_detection.roi_sampler import _ROISampler
from keras_cv.ops.box_matcher import ArgmaxBoxMatcher


class ROISamplerTest(tf.test.TestCase):
def test_roi_sampler(self):
box_matcher = ArgmaxBoxMatcher(thresholds=[0.3], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.3], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_roi_sampler(self):
)

def test_roi_sampler_small_threshold(self):
box_matcher = ArgmaxBoxMatcher(thresholds=[0.1], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.1], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_roi_sampler_small_threshold(self):

def test_roi_sampler_large_threshold(self):
# the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_roi_sampler_large_threshold(self):

def test_roi_sampler_large_threshold_custom_bg_class(self):
# the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_roi_sampler_large_threshold_custom_bg_class(self):

def test_roi_sampler_large_threshold_append_gt_boxes(self):
# the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down Expand Up @@ -204,7 +204,7 @@ def test_roi_sampler_large_threshold_append_gt_boxes(self):
self.assertAllGreaterEqual(tf.reduce_min(sampled_gt_classes), 0)

def test_roi_sampler_large_num_sampled_rois(self):
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand All @@ -227,7 +227,7 @@ def test_roi_sampler_large_num_sampled_rois(self):
_, _, _ = roi_sampler(rois, gt_boxes, gt_classes)

def test_serialization(self):
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
roi_sampler = _ROISampler(
bounding_box_format="xyxy",
roi_matcher=box_matcher,
Expand Down
8 changes: 4 additions & 4 deletions keras_cv/layers/object_detection/rpn_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

from keras_cv import bounding_box
from keras_cv.bounding_box import iou
from keras_cv.ops import box_matcher
from keras_cv.ops import sampling
from keras_cv.ops import target_gather
from keras_cv.layers.object_detection import box_matcher
from keras_cv.layers.object_detection import sampling
from keras_cv.layers.object_detection import target_gather


@tf.keras.utils.register_keras_serializable(package="keras_cv")
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
self.negative_threshold = negative_threshold
self.samples_per_image = samples_per_image
self.positive_fraction = positive_fraction
self.box_matcher = box_matcher.ArgmaxBoxMatcher(
self.box_matcher = box_matcher.BoxMatcher(
thresholds=[negative_threshold, positive_threshold],
match_values=[-1, -2, 1],
force_match_for_each_col=False,
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import tensorflow as tf

from keras_cv.ops.sampling import balanced_sample
from keras_cv.layers.object_detection.sampling import balanced_sample


class BalancedSamplingTest(tf.test.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import tensorflow as tf


def _target_gather(
targets: tf.Tensor,
indices: tf.Tensor,
mask: Optional[tf.Tensor] = None,
mask_val: Optional[float] = 0.0,
mask=None,
mask_val=0.0,
):
"""A utility function wrapping tf.gather, which deals with:
1) both batched and unbatched `targets`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import tensorflow as tf

from keras_cv.ops.target_gather import _target_gather
from keras_cv.layers.object_detection.target_gather import _target_gather


class TargetGatherTest(tf.test.TestCase):
Expand Down
6 changes: 2 additions & 4 deletions keras_cv/models/object_detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from keras_cv.bounding_box.converters import _decode_deltas_to_boxes
from keras_cv.bounding_box.utils import _clip_boxes
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.layers.object_detection.roi_align import _ROIAligner
from keras_cv.layers.object_detection.roi_generator import ROIGenerator
from keras_cv.layers.object_detection.roi_sampler import _ROISampler
from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder
from keras_cv.models.object_detection import predict_utils
from keras_cv.ops.box_matcher import ArgmaxBoxMatcher

BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2]

Expand Down Expand Up @@ -298,9 +298,7 @@ def __init__(
nms_score_threshold_train=float("-inf"),
nms_score_threshold_test=float("-inf"),
)
self.box_matcher = ArgmaxBoxMatcher(
thresholds=[0.0, 0.5], match_values=[-2, -1, 1]
)
self.box_matcher = BoxMatcher(thresholds=[0.0, 0.5], match_values=[-2, -1, 1])
self.roi_sampler = _ROISampler(
bounding_box_format="yxyx",
roi_matcher=self.box_matcher,
Expand Down