Skip to content

Commit 37cd993

Browse files
authored
[1/3~] Begin cleaning up ops namespace (#1299)
* Begin cleaning up ops namespace * API start updates * Fix imports * move util
1 parent 3388639 commit 37cd993

File tree

14 files changed

+282
-49
lines changed

14 files changed

+282
-49
lines changed

keras_cv/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from keras_cv.layers.fusedmbconv import FusedMBConvBlock
2222
from keras_cv.layers.mbconv import MBConvBlock
2323
from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator
24+
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
2425
from keras_cv.layers.object_detection.multi_class_non_max_suppression import (
2526
MultiClassNonMaxSuppression,
2627
)

keras_cv/ops/box_matcher.py renamed to keras_cv/layers/object_detection/box_matcher.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
"Argmax-based box matching"
16-
1714
from typing import List
1815
from typing import Tuple
1916

2017
import tensorflow as tf
2118

2219

2320
@tf.keras.utils.register_keras_serializable(package="keras_cv")
24-
class ArgmaxBoxMatcher(tf.keras.layers.Layer):
21+
class BoxMatcher(tf.keras.layers.Layer):
2522
"""Box matching logic based on argmax of highest value (e.g., IOU).
2623
2724
This class computes matches from a similarity matrix. Each row will be
@@ -69,8 +66,8 @@ class ArgmaxBoxMatcher(tf.keras.layers.Layer):
6966
Usage:
7067
7168
```python
72-
box_matcher = keras_cv.ops.ArgmaxBoxMatcher([0.3, 0.7], [-1, 0, 1])
73-
iou_metric = keras_cv.bounding_box.compute_iou(anchors, gt_boxes)
69+
box_matcher = keras_cv.layers.BoxMatcher([0.3, 0.7], [-1, 0, 1])
70+
iou_metric = keras_cv.bounding_box.compute_iou(anchors, boxes)
7471
matched_columns, matched_match_values = box_matcher(iou_metric)
7572
cls_mask = tf.less_equal(matched_match_values, 0)
7673
```
@@ -135,7 +132,7 @@ def _match_when_cols_are_empty():
135132
storing the match type indicator (e.g. positive or negative
136133
or ignored match).
137134
"""
138-
with tf.name_scope("empty_gt_boxes"):
135+
with tf.name_scope("empty_boxes"):
139136
matched_columns = tf.zeros([batch_size, num_rows], dtype=tf.int32)
140137
matched_values = -tf.ones([batch_size, num_rows], dtype=tf.int32)
141138
return matched_columns, matched_values
@@ -149,7 +146,7 @@ def _match_when_cols_are_non_empty():
149146
storing the match type indicator (e.g. positive or negative
150147
or ignored match).
151148
"""
152-
with tf.name_scope("non_empty_gt_boxes"):
149+
with tf.name_scope("non_empty_boxes"):
153150
matched_columns = tf.argmax(
154151
similarity_matrix, axis=-1, output_type=tf.int32
155152
)
@@ -207,11 +204,11 @@ def _match_when_cols_are_non_empty():
207204

208205
return matched_columns, matched_values
209206

210-
num_gt_boxes = (
207+
num_boxes = (
211208
similarity_matrix.shape.as_list()[-1] or tf.shape(similarity_matrix)[-1]
212209
)
213210
matched_columns, matched_values = tf.cond(
214-
pred=tf.greater(num_gt_boxes, 0),
211+
pred=tf.greater(num_boxes, 0),
215212
true_fn=_match_when_cols_are_non_empty,
216213
false_fn=_match_when_cols_are_empty,
217214
)

keras_cv/ops/box_matcher_test.py renamed to keras_cv/layers/object_detection/box_matcher_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414

1515
import tensorflow as tf
1616

17-
from keras_cv.ops.box_matcher import ArgmaxBoxMatcher
17+
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
1818

1919

20-
class ArgmaxBoxMatcherTest(tf.test.TestCase):
20+
class BoxMatcherTest(tf.test.TestCase):
2121
def test_box_matcher_invalid_length(self):
2222
fg_threshold = 0.5
2323
bg_thresh_hi = 0.2
2424
bg_thresh_lo = 0.0
2525

2626
with self.assertRaisesRegex(ValueError, "must be len"):
27-
_ = ArgmaxBoxMatcher(
27+
_ = BoxMatcher(
2828
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
2929
match_values=[-3, -2, -1],
3030
)
@@ -35,7 +35,7 @@ def test_box_matcher_unsorted_thresholds(self):
3535
bg_thresh_lo = 0.0
3636

3737
with self.assertRaisesRegex(ValueError, "must be sorted"):
38-
_ = ArgmaxBoxMatcher(
38+
_ = BoxMatcher(
3939
thresholds=[bg_thresh_hi, bg_thresh_lo, fg_threshold],
4040
match_values=[-3, -2, -1, 1],
4141
)
@@ -47,7 +47,7 @@ def test_box_matcher_unbatched(self):
4747
bg_thresh_hi = 0.2
4848
bg_thresh_lo = 0.0
4949

50-
matcher = ArgmaxBoxMatcher(
50+
matcher = BoxMatcher(
5151
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
5252
match_values=[-3, -2, -1, 1],
5353
)
@@ -67,7 +67,7 @@ def test_box_matcher_batched(self):
6767
bg_thresh_hi = 0.2
6868
bg_thresh_lo = 0.0
6969

70-
matcher = ArgmaxBoxMatcher(
70+
matcher = BoxMatcher(
7171
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
7272
match_values=[-3, -2, -1, 1],
7373
)
@@ -90,7 +90,7 @@ def test_box_matcher_force_match(self):
9090
bg_thresh_hi = 0.2
9191
bg_thresh_lo = 0.0
9292

93-
matcher = ArgmaxBoxMatcher(
93+
matcher = BoxMatcher(
9494
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
9595
match_values=[-3, -2, -1, 1],
9696
force_match_for_each_col=True,
@@ -113,7 +113,7 @@ def test_box_matcher_empty_gt_boxes(self):
113113
bg_thresh_hi = 0.2
114114
bg_thresh_lo = 0.0
115115

116-
matcher = ArgmaxBoxMatcher(
116+
matcher = BoxMatcher(
117117
thresholds=[bg_thresh_lo, bg_thresh_hi, fg_threshold],
118118
match_values=[-3, -2, -1, 1],
119119
)

keras_cv/layers/object_detection/retina_net_label_encoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from tensorflow.keras import layers
1717

1818
from keras_cv import bounding_box
19-
from keras_cv.ops import box_matcher
20-
from keras_cv.ops import target_gather
19+
from keras_cv.layers.object_detection import box_matcher
20+
from keras_cv.utils import target_gather
2121

2222

2323
class RetinaNetLabelEncoder(layers.Layer):
@@ -66,7 +66,7 @@ def __init__(
6666
)
6767
self.positive_threshold = positive_threshold
6868
self.negative_threshold = negative_threshold
69-
self.box_matcher = box_matcher.ArgmaxBoxMatcher(
69+
self.box_matcher = box_matcher.BoxMatcher(
7070
thresholds=[negative_threshold, positive_threshold],
7171
match_values=[-1, -2, 1],
7272
force_match_for_each_col=False,

keras_cv/layers/object_detection/roi_sampler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
from keras_cv import bounding_box
1818
from keras_cv.bounding_box import iou
19-
from keras_cv.ops import box_matcher
20-
from keras_cv.ops import sampling
21-
from keras_cv.ops import target_gather
19+
from keras_cv.layers.object_detection import box_matcher
20+
from keras_cv.layers.object_detection import sampling
21+
from keras_cv.layers.object_detection import target_gather
2222

2323

2424
@tf.keras.utils.register_keras_serializable(package="keras_cv")
@@ -44,7 +44,7 @@ class _ROISampler(tf.keras.layers.Layer):
4444
bounding_box_format: The format of bounding boxes to generate. Refer
4545
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
4646
for more details on supported bounding box formats.
47-
roi_matcher: a `ArgmaxBoxMatcher` object that matches proposals
47+
roi_matcher: a `BoxMatcher` object that matches proposals
4848
with ground truth boxes. the positive match must be 1 and negative match must be -1.
4949
Such assumption is not being validated here.
5050
positive_fraction: the positive ratio w.r.t `num_sampled_rois`. Defaults to 0.25.
@@ -59,7 +59,7 @@ class _ROISampler(tf.keras.layers.Layer):
5959
def __init__(
6060
self,
6161
bounding_box_format: str,
62-
roi_matcher: box_matcher.ArgmaxBoxMatcher,
62+
roi_matcher: box_matcher.BoxMatcher,
6363
positive_fraction: float = 0.25,
6464
background_class: int = 0,
6565
num_sampled_rois: int = 256,
@@ -205,5 +205,5 @@ def get_config(self):
205205
@classmethod
206206
def from_config(cls, config, custom_objects=None):
207207
roi_matcher_config = config.pop("roi_matcher")
208-
roi_matcher = box_matcher.ArgmaxBoxMatcher(**roi_matcher_config)
208+
roi_matcher = box_matcher.BoxMatcher(**roi_matcher_config)
209209
return cls(roi_matcher=roi_matcher, **config)

keras_cv/layers/object_detection/roi_sampler_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414

1515
import tensorflow as tf
1616

17+
from keras_cv.layers.object_detection.box_matcher import BoxMatcher
1718
from keras_cv.layers.object_detection.roi_sampler import _ROISampler
18-
from keras_cv.ops.box_matcher import ArgmaxBoxMatcher
1919

2020

2121
class ROISamplerTest(tf.test.TestCase):
2222
def test_roi_sampler(self):
23-
box_matcher = ArgmaxBoxMatcher(thresholds=[0.3], match_values=[-1, 1])
23+
box_matcher = BoxMatcher(thresholds=[0.3], match_values=[-1, 1])
2424
roi_sampler = _ROISampler(
2525
bounding_box_format="xyxy",
2626
roi_matcher=box_matcher,
@@ -57,7 +57,7 @@ def test_roi_sampler(self):
5757
)
5858

5959
def test_roi_sampler_small_threshold(self):
60-
box_matcher = ArgmaxBoxMatcher(thresholds=[0.1], match_values=[-1, 1])
60+
box_matcher = BoxMatcher(thresholds=[0.1], match_values=[-1, 1])
6161
roi_sampler = _ROISampler(
6262
bounding_box_format="xyxy",
6363
roi_matcher=box_matcher,
@@ -106,7 +106,7 @@ def test_roi_sampler_small_threshold(self):
106106

107107
def test_roi_sampler_large_threshold(self):
108108
# the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it
109-
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
109+
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
110110
roi_sampler = _ROISampler(
111111
bounding_box_format="xyxy",
112112
roi_matcher=box_matcher,
@@ -139,7 +139,7 @@ def test_roi_sampler_large_threshold(self):
139139

140140
def test_roi_sampler_large_threshold_custom_bg_class(self):
141141
# the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it
142-
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
142+
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
143143
roi_sampler = _ROISampler(
144144
bounding_box_format="xyxy",
145145
roi_matcher=box_matcher,
@@ -173,7 +173,7 @@ def test_roi_sampler_large_threshold_custom_bg_class(self):
173173

174174
def test_roi_sampler_large_threshold_append_gt_boxes(self):
175175
# the 2nd roi and 2nd gt box has IOU of 0.923, setting positive_threshold to 0.95 to ignore it
176-
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
176+
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
177177
roi_sampler = _ROISampler(
178178
bounding_box_format="xyxy",
179179
roi_matcher=box_matcher,
@@ -204,7 +204,7 @@ def test_roi_sampler_large_threshold_append_gt_boxes(self):
204204
self.assertAllGreaterEqual(tf.reduce_min(sampled_gt_classes), 0)
205205

206206
def test_roi_sampler_large_num_sampled_rois(self):
207-
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
207+
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
208208
roi_sampler = _ROISampler(
209209
bounding_box_format="xyxy",
210210
roi_matcher=box_matcher,
@@ -227,7 +227,7 @@ def test_roi_sampler_large_num_sampled_rois(self):
227227
_, _, _ = roi_sampler(rois, gt_boxes, gt_classes)
228228

229229
def test_serialization(self):
230-
box_matcher = ArgmaxBoxMatcher(thresholds=[0.95], match_values=[-1, 1])
230+
box_matcher = BoxMatcher(thresholds=[0.95], match_values=[-1, 1])
231231
roi_sampler = _ROISampler(
232232
bounding_box_format="xyxy",
233233
roi_matcher=box_matcher,

keras_cv/layers/object_detection/rpn_label_encoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
from keras_cv import bounding_box
2020
from keras_cv.bounding_box import iou
21-
from keras_cv.ops import box_matcher
22-
from keras_cv.ops import sampling
23-
from keras_cv.ops import target_gather
21+
from keras_cv.layers.object_detection import box_matcher
22+
from keras_cv.layers.object_detection import sampling
23+
from keras_cv.utils import target_gather
2424

2525

2626
@tf.keras.utils.register_keras_serializable(package="keras_cv")
@@ -74,7 +74,7 @@ def __init__(
7474
self.negative_threshold = negative_threshold
7575
self.samples_per_image = samples_per_image
7676
self.positive_fraction = positive_fraction
77-
self.box_matcher = box_matcher.ArgmaxBoxMatcher(
77+
self.box_matcher = box_matcher.BoxMatcher(
7878
thresholds=[negative_threshold, positive_threshold],
7979
match_values=[-1, -2, 1],
8080
force_match_for_each_col=False,
File renamed without changes.

keras_cv/ops/sampling_test.py renamed to keras_cv/layers/object_detection/sampling_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import tensorflow as tf
1616

17-
from keras_cv.ops.sampling import balanced_sample
17+
from keras_cv.layers.object_detection.sampling import balanced_sample
1818

1919

2020
class BalancedSamplingTest(tf.test.TestCase):

keras_cv/ops/target_gather.py renamed to keras_cv/layers/object_detection/target_gather.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional
16-
1715
import tensorflow as tf
1816

1917

2018
def _target_gather(
2119
targets: tf.Tensor,
2220
indices: tf.Tensor,
23-
mask: Optional[tf.Tensor] = None,
24-
mask_val: Optional[float] = 0.0,
21+
mask=None,
22+
mask_val=0.0,
2523
):
2624
"""A utility function wrapping tf.gather, which deals with:
2725
1) both batched and unbatched `targets`

0 commit comments

Comments
 (0)