Skip to content

Commit ab3c8f5

Browse files
authored
Add rand_augment processing layer (#20716)
* Add rand_augment init * Update rand_augment init * Add rand_augment * Add NotImplementedError * Add some test cases * Fix failed test case * Update rand_augment * Update rand_augment test * Fix random_rotation bug * Add build method to supress warning. * Add implementation for transform_bboxes
1 parent 8f04616 commit ab3c8f5

File tree

7 files changed

+361
-4
lines changed

7 files changed

+361
-4
lines changed

keras/api/_tf_keras/keras/layers/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@
152152
MaxNumBoundingBoxes,
153153
)
154154
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
155+
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
156+
RandAugment,
157+
)
155158
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
156159
RandomBrightness,
157160
)

keras/api/layers/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@
152152
MaxNumBoundingBoxes,
153153
)
154154
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
155+
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
156+
RandAugment,
157+
)
155158
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
156159
RandomBrightness,
157160
)

keras/src/layers/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@
9696
MaxNumBoundingBoxes,
9797
)
9898
from keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp
99+
from keras.src.layers.preprocessing.image_preprocessing.rand_augment import (
100+
RandAugment,
101+
)
99102
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
100103
RandomBrightness,
101104
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
import random
2+
3+
import keras.src.layers as layers
4+
from keras.src.api_export import keras_export
5+
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
6+
BaseImagePreprocessingLayer,
7+
)
8+
from keras.src.random import SeedGenerator
9+
from keras.src.utils import backend_utils
10+
11+
12+
@keras_export("keras.layers.RandAugment")
13+
class RandAugment(BaseImagePreprocessingLayer):
14+
"""RandAugment performs the Rand Augment operation on input images.
15+
16+
This layer can be thought of as an all-in-one image augmentation layer. The
17+
policy implemented by this layer has been benchmarked extensively and is
18+
effective on a wide variety of datasets.
19+
20+
References:
21+
- [RandAugment](https://arxiv.org/abs/1909.13719)
22+
23+
Args:
24+
value_range: The range of values the input image can take.
25+
Default is `(0, 255)`. Typically, this would be `(0, 1)`
26+
for normalized images or `(0, 255)` for raw images.
27+
num_ops: The number of augmentation operations to apply sequentially
28+
to each image. Default is 2.
29+
factor: The strength of the augmentation as a normalized value
30+
between 0 and 1. Default is 0.5.
31+
interpolation: The interpolation method to use for resizing operations.
32+
Options include `nearest`, `bilinear`. Default is `bilinear`.
33+
seed: Integer. Used to create a random seed.
34+
35+
"""
36+
37+
_USE_BASE_FACTOR = False
38+
_FACTOR_BOUNDS = (0, 1)
39+
40+
_AUGMENT_LAYERS = [
41+
"random_shear",
42+
"random_translation",
43+
"random_rotation",
44+
"random_brightness",
45+
"random_color_degeneration",
46+
"random_contrast",
47+
"random_sharpness",
48+
"random_posterization",
49+
"solarization",
50+
"auto_contrast",
51+
"equalization",
52+
]
53+
54+
def __init__(
55+
self,
56+
value_range=(0, 255),
57+
num_ops=2,
58+
factor=0.5,
59+
interpolation="bilinear",
60+
seed=None,
61+
data_format=None,
62+
**kwargs,
63+
):
64+
super().__init__(data_format=data_format, **kwargs)
65+
66+
self.value_range = value_range
67+
self.num_ops = num_ops
68+
self._set_factor(factor)
69+
self.interpolation = interpolation
70+
self.seed = seed
71+
self.generator = SeedGenerator(seed)
72+
73+
self.random_shear = layers.RandomShear(
74+
x_factor=self.factor,
75+
y_factor=self.factor,
76+
interpolation=interpolation,
77+
seed=self.seed,
78+
data_format=data_format,
79+
**kwargs,
80+
)
81+
82+
self.random_translation = layers.RandomTranslation(
83+
height_factor=self.factor,
84+
width_factor=self.factor,
85+
interpolation=interpolation,
86+
seed=self.seed,
87+
data_format=data_format,
88+
**kwargs,
89+
)
90+
91+
self.random_rotation = layers.RandomRotation(
92+
factor=self.factor,
93+
interpolation=interpolation,
94+
seed=self.seed,
95+
data_format=data_format,
96+
**kwargs,
97+
)
98+
99+
self.random_brightness = layers.RandomBrightness(
100+
factor=self.factor,
101+
value_range=self.value_range,
102+
seed=self.seed,
103+
data_format=data_format,
104+
**kwargs,
105+
)
106+
107+
self.random_color_degeneration = layers.RandomColorDegeneration(
108+
factor=self.factor,
109+
value_range=self.value_range,
110+
seed=self.seed,
111+
data_format=data_format,
112+
**kwargs,
113+
)
114+
115+
self.random_contrast = layers.RandomContrast(
116+
factor=self.factor,
117+
value_range=self.value_range,
118+
seed=self.seed,
119+
data_format=data_format,
120+
**kwargs,
121+
)
122+
123+
self.random_sharpness = layers.RandomSharpness(
124+
factor=self.factor,
125+
value_range=self.value_range,
126+
seed=self.seed,
127+
data_format=data_format,
128+
**kwargs,
129+
)
130+
131+
self.solarization = layers.Solarization(
132+
addition_factor=self.factor,
133+
threshold_factor=self.factor,
134+
value_range=self.value_range,
135+
seed=self.seed,
136+
data_format=data_format,
137+
**kwargs,
138+
)
139+
140+
self.random_posterization = layers.RandomPosterization(
141+
factor=max(1, int(8 * self.factor[1])),
142+
value_range=self.value_range,
143+
seed=self.seed,
144+
data_format=data_format,
145+
**kwargs,
146+
)
147+
148+
self.auto_contrast = layers.AutoContrast(
149+
value_range=self.value_range, data_format=data_format, **kwargs
150+
)
151+
152+
self.equalization = layers.Equalization(
153+
value_range=self.value_range, data_format=data_format, **kwargs
154+
)
155+
156+
def build(self, input_shape):
157+
for layer_name in self._AUGMENT_LAYERS:
158+
augmentation_layer = getattr(self, layer_name)
159+
augmentation_layer.build(input_shape)
160+
161+
def get_random_transformation(self, data, training=True, seed=None):
162+
if not training:
163+
return None
164+
165+
if backend_utils.in_tf_graph():
166+
self.backend.set_backend("tensorflow")
167+
168+
for layer_name in self._AUGMENT_LAYERS:
169+
augmentation_layer = getattr(self, layer_name)
170+
augmentation_layer.backend.set_backend("tensorflow")
171+
172+
transformation = {}
173+
random.shuffle(self._AUGMENT_LAYERS)
174+
for layer_name in self._AUGMENT_LAYERS[: self.num_ops]:
175+
augmentation_layer = getattr(self, layer_name)
176+
transformation[layer_name] = (
177+
augmentation_layer.get_random_transformation(
178+
data,
179+
training=training,
180+
seed=self._get_seed_generator(self.backend._backend),
181+
)
182+
)
183+
184+
return transformation
185+
186+
def transform_images(self, images, transformation, training=True):
187+
if training:
188+
images = self.backend.cast(images, self.compute_dtype)
189+
190+
for layer_name, transformation_value in transformation.items():
191+
augmentation_layer = getattr(self, layer_name)
192+
images = augmentation_layer.transform_images(
193+
images, transformation_value
194+
)
195+
196+
images = self.backend.cast(images, self.compute_dtype)
197+
return images
198+
199+
def transform_labels(self, labels, transformation, training=True):
200+
return labels
201+
202+
def transform_bounding_boxes(
203+
self,
204+
bounding_boxes,
205+
transformation,
206+
training=True,
207+
):
208+
if training:
209+
for layer_name, transformation_value in transformation.items():
210+
augmentation_layer = getattr(self, layer_name)
211+
bounding_boxes = augmentation_layer.transform_bounding_boxes(
212+
bounding_boxes, transformation_value, training=training
213+
)
214+
return bounding_boxes
215+
216+
def transform_segmentation_masks(
217+
self, segmentation_masks, transformation, training=True
218+
):
219+
return self.transform_images(
220+
segmentation_masks, transformation, training=training
221+
)
222+
223+
def compute_output_shape(self, input_shape):
224+
return input_shape
225+
226+
def get_config(self):
227+
config = {
228+
"value_range": self.value_range,
229+
"num_ops": self.num_ops,
230+
"factor": self.factor,
231+
"interpolation": self.interpolation,
232+
"seed": self.seed,
233+
}
234+
base_config = super().get_config()
235+
return {**base_config, **config}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import numpy as np
2+
import pytest
3+
from tensorflow import data as tf_data
4+
5+
from keras.src import backend
6+
from keras.src import layers
7+
from keras.src import testing
8+
9+
10+
class RandAugmentTest(testing.TestCase):
11+
@pytest.mark.requires_trainable_backend
12+
def test_layer(self):
13+
self.run_layer_test(
14+
layers.RandAugment,
15+
init_kwargs={
16+
"value_range": (0, 255),
17+
"num_ops": 2,
18+
"factor": 1,
19+
"interpolation": "nearest",
20+
"seed": 1,
21+
"data_format": "channels_last",
22+
},
23+
input_shape=(8, 3, 4, 3),
24+
supports_masking=False,
25+
expected_output_shape=(8, 3, 4, 3),
26+
)
27+
28+
def test_rand_augment_inference(self):
29+
seed = 3481
30+
layer = layers.RandAugment()
31+
32+
np.random.seed(seed)
33+
inputs = np.random.randint(0, 255, size=(224, 224, 3))
34+
output = layer(inputs, training=False)
35+
self.assertAllClose(inputs, output)
36+
37+
def test_rand_augment_basic(self):
38+
data_format = backend.config.image_data_format()
39+
if data_format == "channels_last":
40+
input_data = np.random.random((2, 8, 8, 3))
41+
else:
42+
input_data = np.random.random((2, 3, 8, 8))
43+
layer = layers.RandAugment(data_format=data_format)
44+
45+
augmented_image = layer(input_data)
46+
self.assertEqual(augmented_image.shape, input_data.shape)
47+
48+
def test_rand_augment_no_operations(self):
49+
data_format = backend.config.image_data_format()
50+
if data_format == "channels_last":
51+
input_data = np.random.random((2, 8, 8, 3))
52+
else:
53+
input_data = np.random.random((2, 3, 8, 8))
54+
layer = layers.RandAugment(num_ops=0, data_format=data_format)
55+
56+
augmented_image = layer(input_data)
57+
self.assertAllClose(
58+
backend.convert_to_numpy(augmented_image), input_data
59+
)
60+
61+
def test_random_augment_randomness(self):
62+
data_format = backend.config.image_data_format()
63+
if data_format == "channels_last":
64+
input_data = np.random.random((2, 8, 8, 3))
65+
else:
66+
input_data = np.random.random((2, 3, 8, 8))
67+
68+
layer = layers.RandAugment(num_ops=11, data_format=data_format)
69+
augmented_image = layer(input_data)
70+
71+
self.assertNotAllClose(
72+
backend.convert_to_numpy(augmented_image), input_data
73+
)
74+
75+
def test_tf_data_compatibility(self):
76+
data_format = backend.config.image_data_format()
77+
if data_format == "channels_last":
78+
input_data = np.random.random((2, 8, 8, 3))
79+
else:
80+
input_data = np.random.random((2, 3, 8, 8))
81+
layer = layers.RandAugment(data_format=data_format)
82+
83+
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
84+
for output in ds.take(1):
85+
output.numpy()
86+
87+
def test_rand_augment_tf_data_bounding_boxes(self):
88+
data_format = backend.config.image_data_format()
89+
if data_format == "channels_last":
90+
image_shape = (1, 10, 8, 3)
91+
else:
92+
image_shape = (1, 3, 10, 8)
93+
input_image = np.random.random(image_shape)
94+
bounding_boxes = {
95+
"boxes": np.array(
96+
[
97+
[
98+
[2, 1, 4, 3],
99+
[6, 4, 8, 6],
100+
]
101+
]
102+
),
103+
"labels": np.array([[1, 2]]),
104+
}
105+
106+
input_data = {"images": input_image, "bounding_boxes": bounding_boxes}
107+
108+
ds = tf_data.Dataset.from_tensor_slices(input_data)
109+
layer = layers.RandAugment(
110+
data_format=data_format,
111+
seed=42,
112+
bounding_box_format="xyxy",
113+
)
114+
ds.map(layer)

keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_correctness(self):
3434
seed = 2390
3535

3636
# Always scale up, but randomly between 0 ~ 255
37-
layer = layers.RandomBrightness([0, 1.0])
37+
layer = layers.RandomBrightness([0.1, 1.0])
3838
np.random.seed(seed)
3939
inputs = np.random.randint(0, 255, size=(224, 224, 3))
4040
output = backend.convert_to_numpy(layer(inputs))
@@ -44,7 +44,7 @@ def test_correctness(self):
4444
self.assertTrue(np.mean(diff) > 0)
4545

4646
# Always scale down, but randomly between 0 ~ 255
47-
layer = layers.RandomBrightness([-1.0, 0.0])
47+
layer = layers.RandomBrightness([-1.0, -0.1])
4848
np.random.seed(seed)
4949
inputs = np.random.randint(0, 255, size=(224, 224, 3))
5050
output = backend.convert_to_numpy(layer(inputs))

0 commit comments

Comments
 (0)