Skip to content

Commit 487bebc

Browse files
No public description
PiperOrigin-RevId: 609861906
1 parent 926b408 commit 487bebc

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

official/vision/ops/augment.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,7 @@ def _fill_rectangle_video(image,
665665
image_time = tf.shape(image)[0]
666666
image_height = tf.shape(image)[1]
667667
image_width = tf.shape(image)[2]
668+
image_channels = tf.shape(image)[3]
668669

669670
lower_pad = tf.maximum(0, center_height - half_height)
670671
upper_pad = tf.maximum(0, image_height - center_height - half_height)
@@ -681,7 +682,7 @@ def _fill_rectangle_video(image,
681682
padding_dims,
682683
constant_values=1)
683684
mask = tf.expand_dims(mask, -1)
684-
mask = tf.tile(mask, [1, 1, 1, 3])
685+
mask = tf.tile(mask, [1, 1, 1, image_channels])
685686

686687
if replace is None:
687688
fill = tf.random.normal(tf.shape(image), dtype=image.dtype)

official/vision/ops/augment_test.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -479,12 +479,14 @@ def test_mixup_and_cutmix_smoothes_labels_with_videos(self):
479479
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
480480
1e4) # With tolerance
481481

482-
def test_mixup_changes_video(self):
482+
@parameterized.product(num_channels=[3, 4])
483+
def test_mixup_changes_video(self, num_channels: int):
483484
batch_size = 12
484485
num_classes = 1000
485486
label_smoothing = 0.1
486487

487-
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32)
488+
images = tf.random.normal(
489+
(batch_size, 8, 224, 224, num_channels), dtype=tf.float32)
488490
labels = tf.range(batch_size)
489491
augmenter = augment.MixupAndCutmix(
490492
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
@@ -500,12 +502,14 @@ def test_mixup_changes_video(self):
500502
1e4) # With tolerance
501503
self.assertFalse(tf.math.reduce_all(images == aug_images))
502504

503-
def test_cutmix_changes_video(self):
505+
@parameterized.product(num_channels=[3, 4])
506+
def test_cutmix_changes_video(self, num_channels: int):
504507
batch_size = 12
505508
num_classes = 1000
506509
label_smoothing = 0.1
507510

508-
images = tf.random.normal((batch_size, 8, 224, 224, 3), dtype=tf.float32)
511+
images = tf.random.normal(
512+
(batch_size, 8, 224, 224, num_channels), dtype=tf.float32)
509513
labels = tf.range(batch_size)
510514
augmenter = augment.MixupAndCutmix(
511515
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)

0 commit comments

Comments
 (0)