diff --git a/example/mxnet/train_cifar100_byteps_gc.py b/example/mxnet/train_cifar100_byteps_gc.py index 4273fcb8c..5ffa7298f 100644 --- a/example/mxnet/train_cifar100_byteps_gc.py +++ b/example/mxnet/train_cifar100_byteps_gc.py @@ -179,18 +179,24 @@ def main(): save_dir = '' save_period = 0 + # from https://github.com/weiaicunzai/pytorch-cifar/blob/master/conf/global_settings.py + CIFAR100_TRAIN_MEAN = [0.5070751592371323, + 0.48654887331495095, 0.4409178433670343] + CIFAR100_TRAIN_STD = [0.2673342858792401, + 0.2564384629170883, 0.27615047132568404] + transform_train = transforms.Compose([ gcv_transforms.RandomCrop(32, pad=4), transforms.RandomFlipLeftRight(), transforms.ToTensor(), - transforms.Normalize([0.4914, 0.4822, 0.4465], - [0.2023, 0.1994, 0.2010]) + transforms.Normalize(CIFAR100_TRAIN_MEAN, + CIFAR100_TRAIN_STD) ]) transform_test = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize([0.4914, 0.4822, 0.4465], - [0.2023, 0.1994, 0.2010]) + transforms.Normalize(CIFAR100_TRAIN_MEAN, + CIFAR100_TRAIN_STD) ]) def test(ctx, val_data):