Skip to content

Commit

Permalink
cifar: update (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasperzhong committed Jun 23, 2020
1 parent c17bfb6 commit 7f12e15
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions example/mxnet/train_cifar100_byteps_gc.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,18 +179,24 @@ def main():
save_dir = ''
save_period = 0

# from https://github.com/weiaicunzai/pytorch-cifar/blob/master/conf/global_settings.py
CIFAR100_TRAIN_MEAN = [0.5070751592371323,
0.48654887331495095, 0.4409178433670343]
CIFAR100_TRAIN_STD = [0.2673342858792401,
0.2564384629170883, 0.27615047132568404]

transform_train = transforms.Compose([
gcv_transforms.RandomCrop(32, pad=4),
transforms.RandomFlipLeftRight(),
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2023, 0.1994, 0.2010])
transforms.Normalize(CIFAR100_TRAIN_MEAN,
CIFAR100_TRAIN_STD)
])

transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2023, 0.1994, 0.2010])
transforms.Normalize(CIFAR100_TRAIN_MEAN,
CIFAR100_TRAIN_STD)
])

def test(ctx, val_data):
Expand Down

0 comments on commit 7f12e15

Please sign in to comment.