This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Rebase #13757 to master (#15189)
* Update .gitmodules

* Set ImageNet data augmentation by default

https://github.com/apache/incubator-mxnet/blob/a38278ddebfcc9459d64237086cd7977ec20c70e/example/image-classification/train_imagenet.py#L42

When I train ImageNet with this line commented out, the training accuracy reaches 99% while the validation accuracy stays below 50% (single machine, 8 GPUs, global batch size 2048, ResNet-50). This is clearly overfitting.

When I uncomment this line and rerun with the same experiment settings, both training and validation accuracy converge to about 70%.

This data augmentation therefore appears to be important for ImageNet training. It may be better to enable it by default, so that future developers are not confused by the overfitting issue.

* Add argument for imagenet data augmentation

* Enable data-aug with argument

* Update .gitmodules
ymjiang authored and wkcn committed Jul 11, 2019
1 parent 5ffd598 commit 554b196
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion example/image-classification/common/fit.py
@@ -142,6 +142,8 @@ def add_fit_args(parser):
     train.add_argument('--profile-server-suffix', type=str, default='',
                        help='profile server actions into a file with name like rank1_ followed by this suffix \
                        during distributed training')
+    train.add_argument('--use-imagenet-data-augmentation', type=int, default=0,
+                       help='enable data augmentation of ImageNet data, default disabled')
     return train
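
The new option is an integer flag rather than a boolean switch, so it has to be given a value on the command line. A minimal sketch of how it parses, assuming `train` is an argparse argument group (the group name `'Training'` below is an illustrative assumption; only the `add_argument` call is taken from the diff):

```python
import argparse

parser = argparse.ArgumentParser()
# Assumption: `train` is an argument group; its name here is illustrative only.
train = parser.add_argument_group('Training')
# This call mirrors the one added in the hunk above.
train.add_argument('--use-imagenet-data-augmentation', type=int, default=0,
                   help='enable data augmentation of ImageNet data, default disabled')

# Without the option, the value stays at its default of 0 (disabled).
off = parser.parse_args([])
# With the option, an explicit integer value must follow it.
on = parser.parse_args(['--use-imagenet-data-augmentation', '1'])

print(off.use_imagenet_data_augmentation, on.use_imagenet_data_augmentation)
```

With this definition the augmentation would be requested by passing `--use-imagenet-data-augmentation 1` to `train_imagenet.py`; any non-zero integer is truthy in the later `if args.use_imagenet_data_augmentation:` check.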


@@ -335,4 +337,4 @@ def fit(args, network, data_loader, **kwargs):
     if args.profile_server_suffix:
         mx.profiler.set_state(state='run', profile_process='server')
     if args.profile_worker_suffix:
-        mx.profiler.set_state(state='run', profile_process='worker')
+        mx.profiler.set_state(state='run', profile_process='worker')
4 changes: 2 additions & 2 deletions example/image-classification/train_imagenet.py
@@ -38,8 +38,6 @@ def set_imagenet_aug(aug):
     fit.add_fit_args(parser)
     data.add_data_args(parser)
     data.add_data_aug_args(parser)
-    # uncomment to set standard augmentations for imagenet training
-    # set_imagenet_aug(parser)
     parser.set_defaults(
         # network
         network = 'resnet',
@@ -56,6 +54,8 @@ def set_imagenet_aug(aug):
         dtype = 'float32'
     )
     args = parser.parse_args()
+    if args.use_imagenet_data_augmentation:
+        set_imagenet_aug(parser)

     # load network
     from importlib import import_module
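
One point that may be worth checking in this hunk is the ordering: `set_imagenet_aug(parser)` now runs after `parser.parse_args()`. The body of `set_imagenet_aug` is not shown in this diff; if it works by calling `set_defaults` on the parser, those defaults would not reach the already-parsed `args`. The sketch below illustrates that argparse behaviour and one possible two-pass alternative. `set_demo_aug`, `build_parser`, and the `--random-mirror` option are hypothetical stand-ins, not code from the repository.

```python
import argparse

def set_demo_aug(parser):
    # Hypothetical stand-in for set_imagenet_aug: changes parser defaults.
    parser.set_defaults(random_mirror=1)

def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-imagenet-data-augmentation', type=int, default=0)
    # Hypothetical augmentation option, used only for this illustration.
    parser.add_argument('--random-mirror', type=int, default=0)
    return parser

def parse_defaults_after(argv):
    # Same ordering as the hunk: parse first, then change the defaults.
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.use_imagenet_data_augmentation:
        set_demo_aug(parser)  # modifies the parser, not the existing namespace
    return args

def parse_two_pass(argv):
    # Alternative: read the flag first, set defaults, then do the full parse.
    parser = build_parser()
    flag_args, _ = parser.parse_known_args(argv)
    if flag_args.use_imagenet_data_augmentation:
        set_demo_aug(parser)
    return parser.parse_args(argv)

argv = ['--use-imagenet-data-augmentation', '1']
late = parse_defaults_after(argv)
early = parse_two_pass(argv)
print(late.random_mirror, early.random_mirror)
```

In the two-pass version an explicit `--random-mirror 0` on the command line would still override the augmentation default, which keeps the flag a change of defaults rather than a hard override.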