@@ -479,12 +479,14 @@ def test_mixup_and_cutmix_smoothes_labels_with_videos(self):
479
479
self .assertAllGreaterEqual (aug_labels , label_smoothing / num_classes -
480
480
1e4 ) # With tolerance
481
481
482
- def test_mixup_changes_video (self ):
482
+ @parameterized .product (num_channels = [3 , 4 ])
483
+ def test_mixup_changes_video (self , num_channels : int ):
483
484
batch_size = 12
484
485
num_classes = 1000
485
486
label_smoothing = 0.1
486
487
487
- images = tf .random .normal ((batch_size , 8 , 224 , 224 , 3 ), dtype = tf .float32 )
488
+ images = tf .random .normal (
489
+ (batch_size , 8 , 224 , 224 , num_channels ), dtype = tf .float32 )
488
490
labels = tf .range (batch_size )
489
491
augmenter = augment .MixupAndCutmix (
490
492
mixup_alpha = 1. , cutmix_alpha = 0. , num_classes = num_classes )
@@ -500,12 +502,14 @@ def test_mixup_changes_video(self):
500
502
1e4 ) # With tolerance
501
503
self .assertFalse (tf .math .reduce_all (images == aug_images ))
502
504
503
- def test_cutmix_changes_video (self ):
505
+ @parameterized .product (num_channels = [3 , 4 ])
506
+ def test_cutmix_changes_video (self , num_channels : int ):
504
507
batch_size = 12
505
508
num_classes = 1000
506
509
label_smoothing = 0.1
507
510
508
- images = tf .random .normal ((batch_size , 8 , 224 , 224 , 3 ), dtype = tf .float32 )
511
+ images = tf .random .normal (
512
+ (batch_size , 8 , 224 , 224 , num_channels ), dtype = tf .float32 )
509
513
labels = tf .range (batch_size )
510
514
augmenter = augment .MixupAndCutmix (
511
515
mixup_alpha = 0. , cutmix_alpha = 1. , num_classes = num_classes )
0 commit comments