Skip to content

Commit

Permalink
implement transform_bounding_boxes for random_translation
Browse files Browse the repository at this point in the history
  • Loading branch information
shashaka committed Nov 23, 2024
1 parent ceabd61 commit 6318e60
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
clip_to_image_size,
)
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
convert_format,
)
from keras.src.random.seed_generator import SeedGenerator
from keras.src.utils import backend_utils


@keras_export("keras.layers.RandomTranslation")
Expand Down Expand Up @@ -166,13 +173,99 @@ def transform_images(self, images, transformation, training=True):
def transform_labels(self, labels, transformation, training=True):
    """Return `labels` unchanged.

    Args:
        labels: Labels accompanying the images.
        transformation: Ignored by this method.
        training: Ignored by this method.

    Returns:
        The very same `labels` object that was passed in.
    """
    return labels

def get_transformed_x_y(self, x, y, transform):
    """Map the point `(x, y)` through a flattened projective transform.

    Args:
        x: x coordinate(s) of the point to map.
        y: y coordinate(s) of the point to map.
        transform: Tensor whose last axis is split into the eight
            coefficients `[a0, a1, a2, b0, b1, b2, c0, c1]`.

    Returns:
        Tuple `(x_transformed, y_transformed)` where
        `x_transformed = (a0 * x + a1 * y + a2) / k`,
        `y_transformed = (b0 * x + b1 * y + b2) / k` and
        `k = c0 * x + c1 * y + 1`.
    """
    coefficients = self.backend.numpy.split(transform, 8, axis=-1)
    a0, a1, a2, b0, b1, b2, c0, c1 = coefficients

    # Shared projective denominator for both output coordinates.
    denominator = c0 * x + c1 * y + 1
    mapped_x = (a0 * x + a1 * y + a2) / denominator
    mapped_y = (b0 * x + b1 * y + b2) / denominator
    return mapped_x, mapped_y

def get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor):
    """Return bounding boxes with every corner shifted by the given offsets.

    The `"boxes"` entry is split into four coordinate columns along its
    last axis. `w_shift_factor` is subtracted from the first and third
    columns and `h_shift_factor` from the second and fourth (i.e. the
    x and y columns of the "xyxy" layout the caller converts to).

    Args:
        bounding_boxes: Dict holding a `"boxes"` tensor whose last axis
            has four entries. Other entries are carried over untouched.
        w_shift_factor: Offset subtracted from the x columns.
        h_shift_factor: Offset subtracted from the y columns.

    Returns:
        A new dict with the shifted `"boxes"`; the dict passed in is not
        modified.
    """
    bboxes = bounding_boxes["boxes"]
    x1, y1, x2, y2 = self.backend.numpy.split(bboxes, 4, axis=-1)

    w_shift_factor = self.backend.convert_to_tensor(
        w_shift_factor, dtype=x1.dtype
    )
    h_shift_factor = self.backend.convert_to_tensor(
        h_shift_factor, dtype=x1.dtype
    )

    # Rank-3 boxes get an extra trailing axis on the shifts so they
    # broadcast against the per-box coordinate columns.
    if len(bboxes.shape) == 3:
        w_shift_factor = self.backend.numpy.expand_dims(w_shift_factor, -1)
        h_shift_factor = self.backend.numpy.expand_dims(h_shift_factor, -1)

    # Shallow-copy so the caller's dict keeps its original "boxes" entry.
    shifted = dict(bounding_boxes)
    shifted["boxes"] = self.backend.numpy.concatenate(
        [
            x1 - w_shift_factor,
            y1 - h_shift_factor,
            x2 - w_shift_factor,
            y2 - h_shift_factor,
        ],
        axis=-1,
    )
    return shifted

def transform_bounding_boxes(
    self,
    bounding_boxes,
    transformation,
    training=True,
):
    """Shift bounding boxes to follow the translation applied to images.

    The boxes are converted to "xyxy", shifted by the offset derived from
    the translation matrix, clipped to the image size and converted back
    to `self.bounding_box_format`.

    Args:
        bounding_boxes: Bounding boxes dict expressed in
            `self.bounding_box_format`.
        transformation: Dict providing `"translations"` (passed to
            `self._get_translation_matrix`) and `"input_shape"` (the
            image shape, indexed according to `self.data_format`).
        training: Not used by this method.

    Returns:
        The shifted and clipped bounding boxes in
        `self.bounding_box_format`.
    """
    if backend_utils.in_tf_graph():
        self.backend.set_backend("tensorflow")

    try:
        if self.data_format == "channels_first":
            height_axis = -2
            width_axis = -1
        else:
            height_axis = -3
            width_axis = -2

        input_height, input_width = (
            transformation["input_shape"][height_axis],
            transformation["input_shape"][width_axis],
        )

        bounding_boxes = convert_format(
            bounding_boxes,
            source=self.bounding_box_format,
            target="xyxy",
            height=input_height,
            width=input_width,
        )

        translations = transformation["translations"]
        transform = self._get_translation_matrix(translations)

        # Where the origin lands under the transform gives the per-axis
        # offset that is subtracted from every box corner.
        w_shift_factor, h_shift_factor = self.get_transformed_x_y(
            0, 0, transform
        )
        bounding_boxes = self.get_shifted_bbox(
            bounding_boxes, w_shift_factor, h_shift_factor
        )

        bounding_boxes = clip_to_image_size(
            bounding_boxes=bounding_boxes,
            height=input_height,
            width=input_width,
            format="xyxy",
        )

        bounding_boxes = convert_format(
            bounding_boxes,
            source="xyxy",
            target=self.bounding_box_format,
            height=input_height,
            width=input_width,
        )
    finally:
        # Always undo the backend override, even if a step above raised.
        self.backend.reset()

    return bounding_boxes

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
Expand Down Expand Up @@ -227,7 +320,7 @@ def get_random_transformation(self, data, training=True, seed=None):
),
dtype="float32",
)
return {"translations": translations}
return {"translations": translations, "input_shape": images_shape}

def _translate_inputs(self, inputs, transformation):
if transformation is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.src import backend
from keras.src import layers
from keras.src import testing
from keras.src.utils import backend_utils


class RandomTranslationTest(testing.TestCase):
Expand Down Expand Up @@ -328,3 +329,117 @@ def test_tf_data_compatibility(self):
input_data = np.random.random((1, 4, 4, 3))
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(1).map(layer)
next(iter(ds)).numpy()

@parameterized.named_parameters(
    (
        "with_positive_shift",
        [[1.0, 2.0]],
        [[3.0, 3.0, 5.0, 5.0], [7.0, 6.0, 8.0, 8.0]],
    ),
    (
        "with_negative_shift",
        [[-1.0, -2.0]],
        [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 7.0, 4.0]],
    ),
)
def test_random_flip_bounding_boxes(self, translation, expected_boxes):
    """Check `transform_bounding_boxes` on unbatched "xyxy" boxes.

    A fixed translation is injected through the `transformation` dict,
    and the resulting boxes are compared with `expected_boxes`.
    """
    data_format = backend.config.image_data_format()
    # Height 10 and width 8 in either layout.
    image_shape = (
        (10, 8, 3) if data_format == "channels_last" else (3, 10, 8)
    )

    image = np.random.random(image_shape)
    boxes = np.array([[2, 1, 4, 3], [6, 4, 8, 6]])
    sample = {
        "images": image,
        "bounding_boxes": {"boxes": boxes, "labels": np.array([[1, 2]])},
    }

    layer = layers.RandomTranslation(
        height_factor=0.5,
        width_factor=0.5,
        data_format=data_format,
        seed=42,
        bounding_box_format="xyxy",
    )

    fixed_translations = backend_utils.convert_tf_tensor(
        np.array(translation)
    )
    result = layer.transform_bounding_boxes(
        sample["bounding_boxes"],
        transformation={
            "translations": fixed_translations,
            "input_shape": image_shape,
        },
        training=True,
    )

    self.assertAllClose(result["boxes"], expected_boxes)

@parameterized.named_parameters(
    (
        "with_positive_shift",
        [[1.0, 2.0]],
        [[3.0, 3.0, 5.0, 5.0], [7.0, 6.0, 8.0, 8.0]],
    ),
    (
        "with_negative_shift",
        [[-1.0, -2.0]],
        [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 7.0, 4.0]],
    ),
)
def test_random_flip_tf_data_bounding_boxes(
    self, translation, expected_boxes
):
    """Check `transform_bounding_boxes` when mapped over a `tf.data` dataset.

    A fixed translation is injected through the `transformation` dict,
    the layer method is mapped over the dataset, and the first element's
    boxes are compared with `expected_boxes`.
    """
    data_format = backend.config.image_data_format()
    # Leading axis of size 1 is the one `from_tensor_slices` slices over.
    if data_format == "channels_last":
        image_shape = (1, 10, 8, 3)
    else:
        image_shape = (1, 3, 10, 8)

    image = np.random.random(image_shape)
    boxes = np.array([[[2, 1, 4, 3], [6, 4, 8, 6]]])
    sample = {
        "images": image,
        "bounding_boxes": {"boxes": boxes, "labels": np.array([[1, 2]])},
    }
    dataset = tf_data.Dataset.from_tensor_slices(sample)

    layer = layers.RandomTranslation(
        height_factor=0.5,
        width_factor=0.5,
        data_format=data_format,
        seed=42,
        bounding_box_format="xyxy",
    )

    transformation = {
        "translations": backend_utils.convert_tf_tensor(
            np.array(translation)
        ),
        "input_shape": image_shape,
    }

    def shift_boxes(element):
        return layer.transform_bounding_boxes(
            element["bounding_boxes"],
            transformation=transformation,
            training=True,
        )

    dataset = dataset.map(shift_boxes)

    first = next(iter(dataset))
    self.assertAllClose(first["boxes"], np.array(expected_boxes))

0 comments on commit 6318e60

Please sign in to comment.