-
Notifications
You must be signed in to change notification settings - Fork 45.8k
/
augment.py
2922 lines (2448 loc) · 106 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Augmentation policies for enhanced image/video preprocessing.
AutoAugment:
- AutoAugment Reference: https://arxiv.org/abs/1805.09501
- AutoAugment for Object Detection Reference: https://arxiv.org/abs/1906.11172
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
RandomErasing, Mixup and Cutmix are inspired by
https://github.com/rwightman/pytorch-image-models
SSDRandCrop Reference:
- Liu et al., SSD: Single shot multibox detector:
https://arxiv.org/abs/1512.02325
- Implementation from TF Object Detection API:
https://github.com/tensorflow/models/
"""
from collections.abc import Sequence
import inspect
import math
from typing import Any, Iterable, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf, tf_keras
from official.vision.configs import common as configs
from official.vision.ops import box_ops
# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme. It is written as a float literal (`10.`).
_MAX_LEVEL = 10.
def to_4d(image: tf.Tensor) -> tf.Tensor:
  """Reshapes a 2D, 3D or 4D tensor into a 4D tensor.

  The rank is raised by inserting size-1 dimensions:

    4D image => unchanged, e.g. [N, H, W, C] or [N, C, H, W]
    3D image => [1, H, W, C] or [1, C, H, W]
    2D image => [1, H, W, 1]

  Args:
    image: The 2/3/4D input tensor.

  Returns:
    A 4D image tensor.

  Note: the rank is not validated. Tensors of any other rank are reshaped
  using the same rules (a leading 1 is added for rank <= 3, a trailing 1 for
  rank 2) rather than rejected.
  """
  rank = tf.rank(image)
  # 1 if a leading batch dim must be added, else 0.
  needs_leading = tf.cast(tf.less_equal(rank, 3), dtype=tf.int32)
  # 1 if a trailing channel dim must be added, else 0.
  needs_trailing = tf.cast(tf.equal(rank, 2), dtype=tf.int32)
  target_shape = tf.concat(
      [
          tf.ones(shape=needs_leading, dtype=tf.int32),
          tf.shape(image),
          tf.ones(shape=needs_trailing, dtype=tf.int32),
      ],
      axis=0,
  )
  return tf.reshape(image, target_shape)
def from_4d(image: tf.Tensor, ndims: tf.Tensor) -> tf.Tensor:
  """Reshapes a 4D tensor back to rank `ndims`, undoing `to_4d`.

  Args:
    image: A 4D tensor.
    ndims: The rank to restore. A leading dimension is dropped when
      `ndims <= 3`, and the trailing dimension is dropped when `ndims == 2`.

  Returns:
    `image` reshaped to the selected slice of its own shape.
  """
  full_shape = tf.shape(image)
  start = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)
  stop = 4 - tf.cast(tf.equal(ndims, 2), dtype=tf.int32)
  return tf.reshape(image, full_shape[start:stop])
def _pad(
    image: tf.Tensor,
    filter_shape: Union[List[int], Tuple[int, ...]],
    mode: str = 'CONSTANT',
    constant_values: Union[int, tf.Tensor] = 0,
) -> tf.Tensor:
  """Pads a 4-D image so that a 'VALID' convolution keeps its spatial size.

  This mirrors the implicit 'SAME' padding of `tf.nn.conv2d` and
  `tf.nn.depthwise_conv2d`, but also supports reflect and symmetric modes and
  a non-zero constant. For an even-sized filter, the extra row/column goes to
  the bottom/right side.

  Args:
    image: A 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
    filter_shape: A `tuple`/`list` of 2 integers, the filter height and width.
    mode: One of "REFLECT", "CONSTANT" or "SYMMETRIC" (checked
      case-insensitively), forwarded to `tf.pad`. See
      https://www.tensorflow.org/api_docs/python/tf/pad.
    constant_values: A scalar pad value used in "CONSTANT" mode.

  Returns:
    The padded image.

  Raises:
    ValueError: If `mode` is not one of the supported padding modes.
  """
  if mode.upper() not in {'REFLECT', 'CONSTANT', 'SYMMETRIC'}:
    raise ValueError(
        'padding should be one of "REFLECT", "CONSTANT", or "SYMMETRIC".'
    )
  fill = tf.convert_to_tensor(constant_values, image.dtype)

  def _split(size):
    # Total padding is size - 1; any odd leftover goes on the trailing side.
    before = (size - 1) // 2
    return [before, size - 1 - before]

  height, width = filter_shape
  paddings = [[0, 0], _split(height), _split(width), [0, 0]]
  return tf.pad(image, paddings, mode=mode, constant_values=fill)
def _get_gaussian_kernel(sigma, filter_shape):
  """Builds a normalized 1D Gaussian kernel with `filter_shape` taps.

  Tap offsets run from `-filter_shape // 2 + 1` to `filter_shape // 2`
  inclusive. Weights are proportional to exp(-x^2 / (2 * sigma^2)) and are
  normalized to sum to one via a softmax.

  NOTE(review): sigma == 0 divides by zero here and yields NaN weights.

  Args:
    sigma: Standard deviation of the Gaussian (scalar, floating point).
    filter_shape: Integer number of taps.

  Returns:
    A 1D tensor of kernel weights in the floating dtype of `sigma`.
  """
  sigma = tf.convert_to_tensor(sigma)
  offsets = tf.range(-filter_shape // 2 + 1, filter_shape // 2 + 1)
  squared_offsets = tf.cast(offsets**2, sigma.dtype)
  return tf.nn.softmax(-squared_offsets / (2.0 * (sigma**2)))
def _get_gaussian_kernel_2d(gaussian_filter_x, gaussian_filter_y):
  """Combines two 1D kernels into a 2D kernel with a matrix product.

  In this file `gaussian_filter2d` passes a column vector first and a row
  vector second, so the product is their outer product.
  """
  return tf.matmul(gaussian_filter_x, gaussian_filter_y)
def _normalize_tuple(value, n, name):
"""Transforms an integer or iterable of integers into an integer tuple.
Args:
value: The value to validate and convert. Could an int, or any iterable of
ints.
n: The size of the tuple to be returned.
name: The name of the argument being validated, e.g. "strides" or
"kernel_size". This is only used to format error messages.
Returns:
A tuple of n integers.
Raises:
ValueError: If something else than an int/long or iterable thereof was
passed.
"""
if isinstance(value, int):
return (value,) * n
else:
try:
value_tuple = tuple(value)
except TypeError as exc:
raise TypeError(
f'The {name} argument must be a tuple of {n} integers. '
f'Received: {value}'
) from exc
if len(value_tuple) != n:
raise ValueError(
f'The {name} argument must be a tuple of {n} integers. '
f'Received: {value}'
)
for single_value in value_tuple:
try:
int(single_value)
except (ValueError, TypeError) as exc:
raise ValueError(
f'The {name} argument must be a tuple of {n} integers. Received:'
f' {value} including element {single_value} of type'
f' {type(single_value)}.'
) from exc
return value_tuple
def gaussian_filter2d(
    image: tf.Tensor,
    filter_shape: Union[List[int], Tuple[int, ...], int],
    sigma: Union[List[float], Tuple[float, float], float] = 1.0,
    padding: str = 'REFLECT',
    constant_values: Union[int, tf.Tensor] = 0,
    name: Optional[str] = None,
) -> tf.Tensor:
  """Performs Gaussian blur on image(s).

  Args:
    image: Either a 2-D `Tensor` of shape `[height, width]`, a 3-D `Tensor` of
      shape `[height, width, channels]`, or a 4-D `Tensor` of shape
      `[batch_size, height, width, channels]`.
    filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying the
      height and width of the 2-D gaussian filter. Can be a single integer to
      specify the same value for all spatial dimensions.
    sigma: A `float` or `tuple`/`list` of 2 floats, the standard deviation of
      the 2-D gaussian filter. As with `filter_shape`, the first value applies
      along the height (y) axis and the second along the width (x) axis. A
      single float uses the same value for both.
      NOTE(review): sigma == 0 passes validation but leads to a division by
      zero in the kernel (NaN output).
    padding: A `string`, one of "REFLECT", "CONSTANT", or "SYMMETRIC". The type
      of padding algorithm to use, which is compatible with `mode` argument in
      `tf.pad`. For more details, please refer to
      https://www.tensorflow.org/api_docs/python/tf/pad.
    constant_values: A `scalar`, the pad value to use in "CONSTANT" padding
      mode.
    name: A name for this operation (optional).

  Returns:
    2-D, 3-D or 4-D `Tensor` of the same dtype as input.

  Raises:
    ValueError: If `image` has a statically known rank other than 2, 3 or 4,
      if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC",
      if `filter_shape` is invalid,
      or if `sigma` is invalid.
  """
  with tf.name_scope(name or 'gaussian_filter2d'):
    if isinstance(sigma, (list, tuple)):
      if len(sigma) != 2:
        raise ValueError('sigma should be a float or a tuple/list of 2 floats')
    else:
      sigma = (sigma,) * 2
    if any(s < 0 for s in sigma):
      raise ValueError('sigma should be greater than or equal to 0.')

    image = tf.convert_to_tensor(image, name='image')
    # `to_4d` does not validate the rank, so reject statically-known bad ranks
    # here with a clear message rather than failing later inside the conv.
    static_rank = image.shape.rank
    if static_rank is not None and static_rank not in (2, 3, 4):
      raise ValueError(
          f'image should be 2, 3 or 4-dimensional, got rank {static_rank}.'
      )
    sigma = tf.convert_to_tensor(sigma, name='sigma')

    original_ndims = tf.rank(image)
    image = to_4d(image)

    # Keep the precision if it's float;
    # otherwise, convert to float32 for computing.
    orig_dtype = image.dtype
    if not image.dtype.is_floating:
      image = tf.cast(image, tf.float32)

    channels = tf.shape(image)[3]
    filter_shape = _normalize_tuple(filter_shape, 2, 'filter_shape')

    sigma = tf.cast(sigma, image.dtype)
    # Row kernel (spans width) uses sigma[1]; column kernel (spans height)
    # uses sigma[0].
    gaussian_kernel_x = _get_gaussian_kernel(sigma[1], filter_shape[1])
    gaussian_kernel_x = gaussian_kernel_x[tf.newaxis, :]

    gaussian_kernel_y = _get_gaussian_kernel(sigma[0], filter_shape[0])
    gaussian_kernel_y = gaussian_kernel_y[:, tf.newaxis]

    gaussian_kernel_2d = _get_gaussian_kernel_2d(
        gaussian_kernel_y, gaussian_kernel_x
    )
    gaussian_kernel_2d = gaussian_kernel_2d[:, :, tf.newaxis, tf.newaxis]
    # One copy of the kernel per input channel for the depthwise convolution.
    gaussian_kernel_2d = tf.tile(gaussian_kernel_2d, [1, 1, channels, 1])

    # Pad explicitly so the 'VALID' convolution preserves the spatial size.
    image = _pad(
        image, filter_shape, mode=padding, constant_values=constant_values
    )

    output = tf.nn.depthwise_conv2d(
        input=image,
        filter=gaussian_kernel_2d,
        strides=(1, 1, 1, 1),
        padding='VALID',
    )
    output = from_4d(output, original_ndims)
    return tf.cast(output, orig_dtype)
def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
  """Builds projective transform rows for pure translations.

  Each [dx, dy] becomes the flattened (first 8 entries of the) matrix

    [[1 0 -dx]
     [0 1 -dy]
     [0 0 1]]

  Args:
    translations: A [dx, dy] pair, or a matrix with one [dx, dy] row per
      image. The rank must be statically known.

  Returns:
    A float32 tensor of shape (num_images, 8).

  Raises:
    TypeError: If the rank of `translations` is unknown, or is not 1 or 2.
  """
  translations = tf.convert_to_tensor(translations, dtype=tf.float32)
  static_rank = translations.get_shape().ndims
  if static_rank is None:
    raise TypeError('translations rank must be statically known')
  if static_rank == 1:
    translations = translations[None]
  elif static_rank != 2:
    raise TypeError('translations should have rank 1 or 2.')

  count = tf.shape(translations)[0]
  ones = tf.ones((count, 1), tf.dtypes.float32)
  zeros = tf.zeros((count, 1), tf.dtypes.float32)
  # Row layout: [a0, a1, a2, b0, b1, b2, c0, c1] = [1, 0, -dx, 0, 1, -dy, 0, 0]
  return tf.concat(
      values=[
          ones,
          zeros,
          -translations[:, 0, None],
          zeros,
          ones,
          -translations[:, 1, None],
          tf.zeros((count, 2), tf.dtypes.float32),
      ],
      axis=1,
  )
def _convert_angles_to_transform(angles: tf.Tensor, image_width: tf.Tensor,
                                 image_height: tf.Tensor) -> tf.Tensor:
  """Builds projective transform rows that rotate about the image centre.

  Args:
    angles: Rotation angle(s) in radians, either a scalar applied to all
      images or a rank-1 vector with one angle per image.
    image_width: The width of the image(s) to be transformed.
    image_height: The height of the image(s) to be transformed.

  Returns:
    A float32 tensor of shape (num_angles, 8).

  Raises:
    TypeError: If `angles` is not rank 0 or 1.
  """
  angles = tf.convert_to_tensor(angles, dtype=tf.float32)
  rank = len(angles.get_shape())
  if rank == 0:
    angles = angles[None]
  elif rank != 1:
    raise TypeError('Angles should have a rank 0 or 1.')

  cos = tf.math.cos(angles)
  sin = tf.math.sin(angles)
  max_x = image_width - 1
  max_y = image_height - 1
  # Offsets that keep the centre of the image fixed under the rotation.
  x_offset = (max_x - (cos * max_x - sin * max_y)) / 2.0
  y_offset = (max_y - (sin * max_x + cos * max_y)) / 2.0
  num_angles = tf.shape(angles)[0]
  return tf.concat(
      values=[
          cos[:, None],
          -sin[:, None],
          x_offset[:, None],
          sin[:, None],
          cos[:, None],
          y_offset[:, None],
          tf.zeros((num_angles, 2), tf.dtypes.float32),
      ],
      axis=1,
  )
def _apply_transform_to_images(
    images,
    transforms,
    fill_mode='reflect',
    fill_value=0.0,
    interpolation='bilinear',
    output_shape=None,
    name=None,
):
  """Applies projective transform(s) to a batch of NHWC images.

  Args:
    images: A tensor of shape `(num_images, num_rows, num_columns,
      num_channels)` (NHWC). The rank must be statically known (the shape is
      not `TensorShape(None)`).
    transforms: Projective transform matrices, an N x 8 tensor (or 1 x 8 to
      apply one transform to every image). If one row is [a0, a1, a2, b0, b1,
      b2, c0, c1], it maps the *output* point `(x, y)` to the *input* point
      `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where
      `k = c0 x + c1 y + 1`. The transforms are therefore *inverted* compared
      to the mapping from input points to output points. Gradients are not
      backpropagated into the transformation parameters.
    fill_mode: How points outside the input are filled. One of
      `{"constant", "reflect", "wrap", "nearest"}`, upper-cased before being
      passed to the op:
      - `"reflect"`: `(d c b a | a b c d | d c b a)` The input is extended by
        reflecting about the edge of the last pixel.
      - `"constant"`: `(k k k k | a b c d | k k k k)` The input is extended by
        filling all values beyond the edge with the constant `fill_value`.
      - `"wrap"`: `(a b c d | a b c d | a b c d)` The input is extended by
        wrapping around to the opposite edge.
      - `"nearest"`: `(a a a a | a b c d | d d d d)` The input is extended by
        the nearest pixel.
    fill_value: A float used outside the boundaries when
      `fill_mode="constant"`.
    interpolation: Interpolation mode, `"nearest"` or `"bilinear"`.
    output_shape: Output `[height, width]`. If `None`, the input's spatial
      size is used.
    name: The name of the op (defaults to "transform").

  Returns:
    Image(s) of the same dtype as `images` with the given transform(s)
    applied and spatial size `output_shape`. Points that map outside the
    input are filled according to `fill_mode`.

  Raises:
    ValueError: If `output_shape` is not compatible with a 2-element vector.
  """
  with tf.name_scope(name or 'transform'):
    if output_shape is None:
      output_shape = tf.shape(images)[1:3]
      if not tf.executing_eagerly():
        # While tracing a graph, use the static spatial size when it is known.
        static_shape = tf.get_static_value(output_shape)
        if static_shape is not None:
          output_shape = static_shape

    output_shape = tf.convert_to_tensor(
        output_shape, tf.int32, name='output_shape'
    )
    if not output_shape.get_shape().is_compatible_with([2]):
      raise ValueError(
          'output_shape must be a 1-D Tensor of 2 elements: '
          'new_height, new_width, instead got '
          f'output_shape={output_shape}'
      )

    fill_value = tf.convert_to_tensor(fill_value, tf.float32, name='fill_value')
    return tf.raw_ops.ImageProjectiveTransformV3(
        images=images,
        output_shape=output_shape,
        fill_value=fill_value,
        transforms=transforms,
        fill_mode=fill_mode.upper(),
        interpolation=interpolation.upper(),
    )
def transform(
    image: tf.Tensor,
    transforms: Any,
    interpolation: str = 'nearest',
    output_shape=None,
    fill_mode: str = 'reflect',
    fill_value: float = 0.0,
) -> tf.Tensor:
  """Applies projective transform(s) to an image of rank 2, 3 or 4.

  Args:
    image: The image tensor; it is temporarily reshaped to 4D.
    transforms: A length-8 transform vector or an N x 8 matrix of them.
    interpolation: `"nearest"` or `"bilinear"`.
    output_shape: Optional output `[height, width]`.
    fill_mode: One of `{"constant", "reflect", "wrap", "nearest"}`.
    fill_value: Fill value used when `fill_mode="constant"`.

  Returns:
    The transformed image, reshaped back to the rank of `image`.
  """
  ndims = tf.rank(image)
  matrices = tf.convert_to_tensor(transforms, dtype=tf.float32)
  if matrices.shape.rank == 1:
    # A single 8-vector becomes a 1 x 8 matrix.
    matrices = matrices[None]
  transformed = _apply_transform_to_images(
      images=to_4d(image),
      transforms=matrices,
      interpolation=interpolation,
      fill_mode=fill_mode,
      fill_value=fill_value,
      output_shape=output_shape,
  )
  return from_4d(transformed, ndims)
def translate(
    image: tf.Tensor,
    translations,
    fill_value: float = 0.0,
    fill_mode: str = 'reflect',
    interpolation: str = 'nearest',
) -> tf.Tensor:
  """Shifts image(s) by [dx, dy] pixel offsets.

  Args:
    image: An image Tensor of type uint8.
    translations: A [dx, dy] vector, or a matrix with one [dx, dy] row per
      image.
    fill_value: A float used outside the boundaries when
      `fill_mode="constant"`.
    fill_mode: How points outside the input are filled; one of
      `{"constant", "reflect", "wrap", "nearest"}`.
    interpolation: `"nearest"` or `"bilinear"`.

  Returns:
    The translated image.
  """
  return transform(
      image,
      transforms=_convert_translation_to_transform(translations),  # pytype: disable=wrong-arg-types # always-use-return-annotations
      interpolation=interpolation,
      fill_value=fill_value,
      fill_mode=fill_mode,
  )
def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
  """Rotates the image by `degrees` about its centre.

  Args:
    image: An image Tensor of type uint8.
    degrees: Float, a scalar angle in degrees applied to all images. Per the
      original documentation, positive values rotate clockwise and negative
      values counterclockwise.

  Returns:
    The rotated image, with the same rank as `image`.
  """
  radians = tf.cast(degrees * (math.pi / 180.0), tf.float32)
  ndims = tf.rank(image)
  batch = to_4d(image)
  height = tf.cast(tf.shape(batch)[1], tf.float32)
  width = tf.cast(tf.shape(batch)[2], tf.float32)
  matrices = _convert_angles_to_transform(
      angles=radians, image_width=width, image_height=height)
  # Randomly flipping the rotation direction is left to the caller, which is
  # expected to negate `degrees` itself.
  rotated = transform(batch, transforms=matrices)
  return from_4d(rotated, ndims)
def blend(image1: tf.Tensor, image2: tf.Tensor, factor: float) -> tf.Tensor:
  """Linearly blends `image1` toward `image2` by `factor`.

  `factor == 0` returns `image1` and `factor == 1` returns `image2`; both are
  returned unchanged. Values strictly between 0 and 1 interpolate. Values
  outside [0, 1] extrapolate the difference, and the result is clipped to
  [0, 255].

  Args:
    image1: An image Tensor of type uint8.
    image2: An image Tensor of type uint8.
    factor: A floating point blend factor.

  Returns:
    The blended image as uint8, except in the factor == 0 / 1 shortcuts,
    which return the corresponding input unchanged.
  """
  if factor == 0.0:
    return tf.convert_to_tensor(image1)
  if factor == 1.0:
    return tf.convert_to_tensor(image2)

  base = tf.cast(image1, tf.float32)
  target = tf.cast(image2, tf.float32)
  mixed = base + factor * (target - base)

  if 0.0 < factor < 1.0:
    # Interpolation of in-range inputs stays within [0, 255]; no clip needed.
    return tf.cast(mixed, tf.uint8)
  # Extrapolation can leave the valid range, so clip before casting.
  return tf.cast(tf.clip_by_value(mixed, 0.0, 255.0), tf.uint8)
def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
  """Apply cutout (https://arxiv.org/abs/1708.04552) to image.

  A square of up to (2*pad_size x 2*pad_size) pixels, centred at a uniformly
  random location and clipped at the image borders, is filled with
  `replace`. Rank-4 (video) input is delegated to `cutout_video`, which
  samples its own mask size; `pad_size` is not forwarded in that case.

  Args:
    image: An image Tensor of type uint8, rank 3 (or rank 4 for video).
    pad_size: Half the side length of the square cutout mask.
    replace: What pixel value to fill in the image in the area that has the
      cutout mask applied to it.

  Returns:
    An image Tensor that is of type uint8.

  Raises:
    ValueError: If `image` is not rank 3 or 4.
  """
  rank = image.shape.rank
  if rank not in [3, 4]:
    raise ValueError('Bad image rank: {}'.format(rank))
  if rank == 4:
    return cutout_video(image, replace=replace)

  height = tf.shape(image)[0]
  width = tf.shape(image)[1]
  # The row is sampled before the column, as before, so seeded runs match.
  center_row = tf.random.uniform(
      shape=[], minval=0, maxval=height, dtype=tf.int32)
  center_col = tf.random.uniform(
      shape=[], minval=0, maxval=width, dtype=tf.int32)
  return _fill_rectangle(image, center_col, center_row, pad_size, pad_size,
                         replace)
def _fill_rectangle(image,
                    center_width,
                    center_height,
                    half_width,
                    half_height,
                    replace=None):
  """Fills a rectangle of a [H, W, C] image.

  The rectangle is centred at (`center_height`, `center_width`) and extends
  `half_height` rows and `half_width` columns on each side, clipped at the
  image borders.

  Args:
    image: A 3D image tensor; dims 0, 1 and 2 are read as height, width and
      channels.
    center_width: Column index of the rectangle centre.
    center_height: Row index of the rectangle centre.
    half_width: Half of the rectangle width.
    half_height: Half of the rectangle height.
    replace: Fill value. `None` fills with standard-normal noise of the image
      dtype. A `tf.Tensor` is used as-is and must be broadcastable to
      `image`. Any other value is treated as a scalar constant.

  Returns:
    The image with the rectangle filled.
  """
  image_height = tf.shape(image)[0]
  image_width = tf.shape(image)[1]
  image_channels = tf.shape(image)[2]

  # Distances from each border to the rectangle edge, clamped at 0 where the
  # rectangle would extend past the border.
  lower_pad = tf.maximum(0, center_height - half_height)
  upper_pad = tf.maximum(0, image_height - center_height - half_height)
  left_pad = tf.maximum(0, center_width - half_width)
  right_pad = tf.maximum(0, image_width - center_width - half_width)

  cutout_shape = [
      image_height - (lower_pad + upper_pad),
      image_width - (left_pad + right_pad)
  ]
  padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
  # Zeros inside the rectangle, ones elsewhere.
  mask = tf.pad(
      tf.zeros(cutout_shape, dtype=image.dtype),
      padding_dims,
      constant_values=1)
  mask = tf.expand_dims(mask, -1)
  # Tile to the image's actual channel count rather than a fixed 3, so
  # grayscale or RGBA inputs keep their own channel dimension (matching
  # `_fill_rectangle_video`).
  mask = tf.tile(mask, [1, 1, image_channels])

  if replace is None:
    fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
  elif isinstance(replace, tf.Tensor):
    fill = replace
  else:
    fill = tf.ones_like(image, dtype=image.dtype) * replace
  image = tf.where(tf.equal(mask, 0), fill, image)
  return image
def _fill_rectangle_video(image,
                          center_width,
                          center_height,
                          half_width,
                          half_height,
                          replace=None):
  """Fills the same rectangle in every frame of a [T, H, W, C] video.

  Args:
    image: A 4D video tensor; dims are read as time, height, width, channels.
    center_width: Column index of the rectangle centre.
    center_height: Row index of the rectangle centre.
    half_width: Half of the rectangle width.
    half_height: Half of the rectangle height.
    replace: Fill value. `None` fills with standard-normal noise of the video
      dtype. A `tf.Tensor` is used as-is. Any other value is treated as a
      scalar constant.

  Returns:
    The video with the rectangle filled in all frames.
  """
  dims = tf.shape(image)
  frames, height, width, channels = dims[0], dims[1], dims[2], dims[3]

  top = tf.maximum(0, center_height - half_height)
  bottom = tf.maximum(0, height - center_height - half_height)
  left = tf.maximum(0, center_width - half_width)
  right = tf.maximum(0, width - center_width - half_width)

  # Zeros inside the rectangle (every frame), ones elsewhere.
  hole = tf.zeros(
      [frames, height - (top + bottom), width - (left + right)],
      dtype=image.dtype)
  mask = tf.pad(hole, [[0, 0], [top, bottom], [left, right]],
                constant_values=1)
  mask = tf.tile(mask[..., tf.newaxis], [1, 1, 1, channels])

  if replace is None:
    fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
  elif isinstance(replace, tf.Tensor):
    fill = replace
  else:
    fill = tf.ones_like(image, dtype=image.dtype) * replace
  return tf.where(tf.equal(mask, 0), fill, image)
def cutout_video(
    video: tf.Tensor,
    mask_shape: Optional[tf.Tensor] = None,
    replace: int = 0,
) -> tf.Tensor:
  """Apply cutout (https://arxiv.org/abs/1708.04552) to a video.

  This operation fills a random-size 3D box at a random location within
  `video` with the value `replace`. The centre of the box is sampled
  uniformly over the whole video. If `mask_shape` is not set, the box's
  half-extents are sampled uniformly from roughly [0.25*height, 0.5*height),
  [0.25*width, 0.5*width), and [1, 0.25*depth), where height, width and depth
  (number of frames) are those of the input video. The box extends by its
  half-extent on each side of the centre and is clipped at the video borders.

  Args:
    video: A video Tensor of shape [T, H, W, C].
    mask_shape: An optional integer tensor that specifies the depth, height
      and width of the mask to cut. If it is not set, the shape is randomly
      sampled as described above. Each dimension is halved (rounded down,
      with a minimum of 1) to obtain the half-extent, so odd sizes are
      effectively rounded down to an even size.
    replace: What pixel value to fill in the image in the area that has the
      cutout mask applied to it.

  Returns:
    A video Tensor with cutout applied.
  """
  tf.debugging.assert_shapes([
      (video, ('T', 'H', 'W', 'C')),
  ])

  video_depth = tf.shape(video)[0]
  video_height = tf.shape(video)[1]
  video_width = tf.shape(video)[2]

  # Sample the center location in the image where the zero mask will be applied.
  # NOTE: the order of the random draws below is preserved deliberately;
  # reordering them would change results under a fixed seed.
  cutout_center_height = tf.random.uniform(
      shape=[], minval=0, maxval=video_height, dtype=tf.int32
  )

  cutout_center_width = tf.random.uniform(
      shape=[], minval=0, maxval=video_width, dtype=tf.int32
  )

  cutout_center_depth = tf.random.uniform(
      shape=[], minval=0, maxval=video_depth, dtype=tf.int32
  )

  if mask_shape is not None:
    # Half-extents of the requested box, at least 1 in every dimension.
    pad_shape = tf.maximum(1, mask_shape // 2)
    pad_size_depth, pad_size_height, pad_size_width = (
        pad_shape[0],
        pad_shape[1],
        pad_shape[2],
    )
  else:
    # Random half-extents; `maxval` is exclusive for integer sampling.
    pad_size_height = tf.random.uniform(
        shape=[],
        minval=tf.maximum(1, tf.cast(video_height / 4, tf.int32)),
        maxval=tf.maximum(2, tf.cast(video_height / 2, tf.int32)),
        dtype=tf.int32,
    )
    pad_size_width = tf.random.uniform(
        shape=[],
        minval=tf.maximum(1, tf.cast(video_width / 4, tf.int32)),
        maxval=tf.maximum(2, tf.cast(video_width / 2, tf.int32)),
        dtype=tf.int32,
    )
    pad_size_depth = tf.random.uniform(
        shape=[],
        minval=1,
        maxval=tf.maximum(2, tf.cast(video_depth / 4, tf.int32)),
        dtype=tf.int32,
    )

  # Distances from each video border to the corresponding box edge, clamped
  # at 0 where the box would extend past the border.
  lower_pad = tf.maximum(0, cutout_center_height - pad_size_height)
  upper_pad = tf.maximum(
      0, video_height - cutout_center_height - pad_size_height
  )
  left_pad = tf.maximum(0, cutout_center_width - pad_size_width)
  right_pad = tf.maximum(0, video_width - cutout_center_width - pad_size_width)
  back_pad = tf.maximum(0, cutout_center_depth - pad_size_depth)
  forward_pad = tf.maximum(
      0, video_depth - cutout_center_depth - pad_size_depth
  )

  cutout_shape = [
      video_depth - (back_pad + forward_pad),
      video_height - (lower_pad + upper_pad),
      video_width - (left_pad + right_pad),
  ]
  padding_dims = [[back_pad, forward_pad],
                  [lower_pad, upper_pad],
                  [left_pad, right_pad]]
  # Zeros inside the cutout box, ones elsewhere.
  mask = tf.pad(
      tf.zeros(cutout_shape, dtype=video.dtype), padding_dims, constant_values=1
  )
  mask = tf.expand_dims(mask, -1)
  num_channels = tf.shape(video)[-1]
  # Same box for every channel.
  mask = tf.tile(mask, [1, 1, 1, num_channels])
  video = tf.where(
      tf.equal(mask, 0), tf.ones_like(video, dtype=video.dtype) * replace, video
  )
  return video
def gaussian_noise(
    image: tf.Tensor, low: float = 0.1, high: float = 2.0) -> tf.Tensor:
  """Blurs image(s) with a 3x3 Gaussian filter of random standard deviation.

  Despite the name, no additive noise is applied. The image is smoothed by
  `gaussian_filter2d` with a sigma drawn uniformly from [low, high).

  NOTE(review): sigma is drawn with NumPy in Python rather than as a TF op.
  Inside a traced graph (e.g. a `tf.data` map) it is therefore likely fixed
  once at trace time. Confirm whether per-example randomness is intended.

  Args:
    image: Image tensor accepted by `gaussian_filter2d`.
    low: Lower bound of the sampled sigma.
    high: Upper bound of the sampled sigma.

  Returns:
    The blurred image, same dtype as `image`.
  """
  sigma = np.random.uniform(low=low, high=high)
  return gaussian_filter2d(image, filter_shape=[3, 3], sigma=sigma)
def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
  """Solarizes the input image(s).

  Pixels below `threshold` are kept. Every other pixel is inverted, i.e.
  replaced by `255 - pixel`.
  """
  keep = image < threshold
  return tf.where(keep, image, 255 - image)
def solarize_add(image: tf.Tensor,
                 addition: int = 0,
                 threshold: int = 128) -> tf.Tensor:
  """Additive solarize the input image(s).

  Pixels below `threshold` get `addition` added (in int64, then clipped to
  [0, 255] and cast to uint8). Pixels at or above `threshold` are unchanged.
  Typical values of `addition` lie between -128 and 128.
  """
  widened = tf.cast(image, tf.int64) + addition
  shifted = tf.cast(tf.clip_by_value(widened, 0, 255), tf.uint8)
  return tf.where(image < threshold, shifted, image)
def grayscale(image: tf.Tensor) -> tf.Tensor:
  """Converts an RGB image to grayscale, replicated back to three channels."""
  luma = tf.image.rgb_to_grayscale(image)
  return tf.image.grayscale_to_rgb(luma)
def color(image: tf.Tensor, factor: float) -> tf.Tensor:
  """Equivalent of PIL Color: blends from grayscale (0) to the original (1).

  Factors above 1 extrapolate away from grayscale; see `blend`.
  """
  return blend(grayscale(image), image, factor)
def contrast(image: tf.Tensor, factor: float) -> tf.Tensor:
  """Equivalent of PIL Contrast.

  Blends `image` with a flat gray image whose intensity is the mean
  grayscale value of `image`, rounded to the nearest integer as PIL does.
  A factor of 0 gives the flat gray image, 1 gives the original, and larger
  values increase contrast (clipped, see `blend`).

  Args:
    image: An RGB uint8 image tensor with channels last (a [T, H, W, 3]
      video is also accepted; its mean is taken over all frames).
    factor: Blend factor.

  Returns:
    The contrast-adjusted uint8 image.
  """
  degenerate = tf.image.rgb_to_grayscale(image)
  # Use the true mean intensity. The previous `sum(histogram) / 256` was the
  # pixel count divided by 256, not the mean pixel value.
  mean = tf.reduce_mean(tf.cast(degenerate, tf.float32))
  # Round half up, like PIL's `int(mean + 0.5)`.
  mean = tf.floor(mean + 0.5)
  degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
  return blend(degenerate, image, factor)
def brightness(image: tf.Tensor, factor: float) -> tf.Tensor:
  """Equivalent of PIL Brightness: blends from black (0) to the original (1).

  Factors above 1 brighten, clipped at 255; see `blend`.
  """
  black = tf.zeros_like(image)
  return blend(black, image, factor)
def posterize(image: tf.Tensor, bits: int) -> tf.Tensor:
  """Equivalent of PIL Posterize: keeps the `bits` high bits of 8-bit pixels.

  The low `8 - bits` bits of each pixel are zeroed by shifting right and then
  back left.
  """
  dropped = 8 - bits
  truncated = tf.bitwise.right_shift(image, dropped)
  return tf.bitwise.left_shift(truncated, dropped)
def wrapped_rotate(image: tf.Tensor, degrees: float, replace: int) -> tf.Tensor:
  """Rotates `image` by `degrees`, bracketed by `wrap` / `unwrap`.

  NOTE(review): `wrap`/`unwrap` are defined elsewhere in this module. They
  presumably mark which pixels came from the original image so that
  `unwrap` can fill the rest with `replace`. Confirm.
  """
  rotated = rotate(wrap(image), degrees=degrees)
  return unwrap(rotated, replace)
def translate_x(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
  """Equivalent of PIL Translate in X dimension.

  Passes `[-pixels, 0]` to `translate`, bracketed by `wrap` / `unwrap`, so
  out-of-image pixels are filled with `replace`.
  """
  shifted = translate(wrap(image), [-pixels, 0])
  return unwrap(shifted, replace)
def translate_y(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
  """Equivalent of PIL Translate in Y dimension.

  Passes `[0, -pixels]` to `translate`, bracketed by `wrap` / `unwrap`, so
  out-of-image pixels are filled with `replace`.
  """
  shifted = translate(wrap(image), [0, -pixels])
  return unwrap(shifted, replace)
def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
  """Equivalent of PIL Shearing in X dimension."""
  # A shear parallel to the x axis is the projective transform
  #   [[1, level],
  #    [0, 1    ]]
  # flattened into the 8-element row format used by `transform`.
  shear_row = [1., level, 0., 0., 1., 0., 0., 0.]
  sheared = transform(image=wrap(image), transforms=shear_row)
  return unwrap(sheared, replace)
def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
  """Equivalent of PIL Shearing in Y dimension."""
  # A shear parallel to the y axis is the projective transform
  #   [[1,     0],
  #    [level, 1]]
  # flattened into the 8-element row format used by `transform`.
  shear_row = [1., 0., 0., level, 1., 0., 0., 0.]
  sheared = transform(image=wrap(image), transforms=shear_row)
  return unwrap(sheared, replace)
def autocontrast(image: tf.Tensor) -> tf.Tensor:
  """Implements Autocontrast function from PIL using TF ops.

  Each of the first three channels is stretched independently so that its
  minimum maps to 0 and its maximum to 255. Constant channels are left as-is.

  Args:
    image: A 3D uint8 tensor with RGB channels last.

  Returns:
    The image after it has had autocontrast applied to it and will be of type
    uint8.
  """

  def _stretch(channel: tf.Tensor) -> tf.Tensor:
    """Linearly rescales one 2D channel to span [0, 255]."""
    # Min/max over the whole channel (a histogram-based version could avoid
    # scanning every pixel).
    lo = tf.cast(tf.reduce_min(channel), tf.float32)
    hi = tf.cast(tf.reduce_max(channel), tf.float32)

    def _rescale():
      scale = 255.0 / (hi - lo)
      offset = -lo * scale
      stretched = tf.cast(channel, tf.float32) * scale + offset
      stretched = tf.clip_by_value(stretched, 0.0, 255.0)
      return tf.cast(stretched, tf.uint8)

    # Only rescale when there is a range to stretch (avoids dividing by 0).
    return tf.cond(hi > lo, _rescale, lambda: channel)

  # Assumes RGB: stretch each channel, then restack on the last axis.
  return tf.stack([_stretch(image[..., c]) for c in range(3)], -1)
def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
  """Implements Sharpness function from PIL using TF ops.

  The image is smoothed with PIL's SMOOTH kernel, and the 1-pixel spatial
  border keeps the original values. The result is then blended with the
  original by `factor` (0 = smoothed, 1 = original, >1 = sharpened).

  Args:
    image: A uint8 RGB image, rank 3 ([H, W, 3]) or rank 4 ([T, H, W, 3]).
    factor: Blend factor passed to `blend`.

  Returns:
    The sharpness-adjusted uint8 image.

  Raises:
    ValueError: If `image` is not rank 3 or 4.
  """
  orig_image = image
  rank = orig_image.shape.rank
  if rank not in (3, 4):
    # Report the caller's rank, not the rank after the batch dim added below.
    raise ValueError('Bad image rank: {}'.format(rank))

  image = tf.cast(image, tf.float32)
  # Make image 4D (or 5D for video) for the conv operation.
  image = tf.expand_dims(image, 0)
  # SMOOTH PIL kernel; the weights sum to 13, so dividing normalizes it.
  if rank == 3:
    kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
                         dtype=tf.float32,
                         shape=[3, 3, 1, 1]) / 13.
    # Tile across channel dimension.
    kernel = tf.tile(kernel, [1, 1, 3, 1])
    strides = [1, 1, 1, 1]
    degenerate = tf.nn.depthwise_conv2d(
        image, kernel, strides, padding='VALID', dilations=[1, 1])
  else:
    kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
                         dtype=tf.float32,
                         shape=[1, 3, 3, 1, 1]) / 13.
    strides = [1, 1, 1, 1, 1]
    # Run the kernel across each channel separately (frame by frame, since
    # the kernel has depth 1).
    channels = tf.split(image, 3, axis=-1)
    degenerates = [
        tf.nn.conv3d(channel, kernel, strides, padding='VALID',
                     dilations=[1, 1, 1, 1, 1])
        for channel in channels
    ]
    degenerate = tf.concat(degenerates, -1)

  degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
  degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])

  # 'VALID' convolution shrank H and W by 2, so pad back and use the
  # original pixels on the 1-pixel border.
  mask = tf.ones_like(degenerate)
  paddings = [[0, 0]] * (rank - 3)
  padded_mask = tf.pad(mask, paddings + [[1, 1], [1, 1], [0, 0]])
  padded_degenerate = tf.pad(degenerate, paddings + [[1, 1], [1, 1], [0, 0]])
  result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)

  # Blend the final result.
  return blend(result, orig_image, factor)
def equalize(image: tf.Tensor) -> tf.Tensor: