Additional SOTA ingredients on Classification Recipe #4493

Merged: 27 commits, Oct 22, 2021

Commits (27)
063ca56  Update EMA every X iters. (datumbox, Sep 28, 2021)
02b4d42  Adding AdamW optimizer. (datumbox, Sep 28, 2021)
33a90f7  Adjusting EMA decay scheme. (datumbox, Sep 28, 2021)
cfdeede  Support custom weight decay for Normalization layers. (datumbox, Sep 28, 2021)
7ecc6d8  Fix identation bug. (datumbox, Sep 28, 2021)
0563f9e  Change EMA adjustment. (datumbox, Sep 29, 2021)
764fe02  Merge branch 'main' into references/optimizations (datumbox, Sep 30, 2021)
19e7d49  Merge branch 'main' into references/optimizations (datumbox, Oct 1, 2021)
d188ee0  Quality of life changes to faciliate testing (datumbox, Oct 4, 2021)
a630986  Merge branch 'main' into references/optimizations (datumbox, Oct 5, 2021)
6655dac  ufmt format (datumbox, Oct 5, 2021)
dc0edb9  Fixing imports. (datumbox, Oct 5, 2021)
e4a098f  Merge branch 'main' into references/optimizations (datumbox, Oct 7, 2021)
2e93296  Adding FixRes improvement. (datumbox, Oct 8, 2021)
dadb2f5  Merge branch 'main' into references/optimizations (datumbox, Oct 8, 2021)
6859fa2  Support EMA in store_model_weights. (datumbox, Oct 13, 2021)
8a9e1a8  Merge branch 'main' into references/optimizations (datumbox, Oct 13, 2021)
17eaf48  Merge branch 'main' into references/optimizations (datumbox, Oct 14, 2021)
950636e  Adding interpolation values. (datumbox, Oct 15, 2021)
9a6a443  Change train_crop_size. (datumbox, Oct 17, 2021)
2ce484a  Merge branch 'main' into references/optimizations (datumbox, Oct 17, 2021)
e699eca  Add interpolation option. (datumbox, Oct 17, 2021)
d861b33  Merge branch 'main' into references/optimizations (datumbox, Oct 21, 2021)
9ee69c4  Removing hardcoded interpolation and sizes from the scripts. (datumbox, Oct 21, 2021)
bc5a2bd  Fixing linter. (datumbox, Oct 21, 2021)
14a3323  Incorporating feedback from code review. (datumbox, Oct 21, 2021)
c3c65d2  Merge branch 'main' into references/optimizations (datumbox, Oct 22, 2021)
35 changes: 35 additions & 0 deletions references/classification/README.md
@@ -31,6 +31,17 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
normalization and thus are trained with the default parameters.

### Inception V3

The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.

Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model use the following command:

```
torchrun --nproc_per_node=8 train.py --model inception_v3
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
```

Contributor Author (datumbox):
Since we removed the hardcoding of parameters based on model names, we now need to provide extra parameters.

Contributor:
One thing on my mind, not related to this PR: could we also let users pass kwargs to the models through the command line (in addition to the train.py arguments)? For example, when I train the ViT model, training from scratch and fine-tuning require two different heads, so I want to configure representation_size differently; currently I need to manually change the Python defaults to reflect this. What do you think?

Contributor Author (datumbox):
We will probably need to introduce more parameters to be able to do this. We will do it to enable your work, but it's also part of the reason why the ArgumentParser is a poor solution. Hopefully this will be deprecated by the STL work you are preparing!
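
The kwargs-through-the-CLI idea above was not implemented in this PR; the sketch below is only one hypothetical way it could look, and both the `--model-kwargs` flag and the `parse_model_kwargs` helper are invented names for illustration rather than anything in train.py or torchvision.

```
# Hypothetical sketch only: neither --model-kwargs nor parse_model_kwargs exist in train.py.
import argparse
import ast


def parse_model_kwargs(pairs):
    """Turn ["representation_size=768", "dropout=0.1"] into {"representation_size": 768, "dropout": 0.1}."""
    kwargs = {}
    for pair in pairs:
        key, value = pair.split("=", 1)
        try:
            kwargs[key] = ast.literal_eval(value)  # parses numbers, booleans, tuples, ...
        except (ValueError, SyntaxError):
            kwargs[key] = value  # fall back to a plain string
    return kwargs


parser = argparse.ArgumentParser()
parser.add_argument("--model", default="resnet18", type=str)
parser.add_argument(
    "--model-kwargs", nargs="*", default=[], help="extra key=value pairs forwarded to the model builder"
)
args = parser.parse_args(["--model-kwargs", "representation_size=768"])

model_kwargs = parse_model_kwargs(args.model_kwargs)
# The reference script builds models via torchvision.models.__dict__[args.model](...),
# so the parsed kwargs could simply be forwarded there:
# model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, **model_kwargs)
print(model_kwargs)  # {'representation_size': 768}
```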

### ResNext-50 32x4d
```
torchrun --nproc_per_node=8 train.py\
@@ -79,6 +90,25 @@ The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](ht

The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).

All models were trained using Bicubic interpolation and each have custom crop and resize sizes. To validate the models use the following commands:
```
torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\
--val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\
--val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\
--val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\
--val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\
--val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\
--val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\
--val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\
--val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained
```

### RegNet

@@ -181,3 +211,8 @@ For post training quant, device is set to CPU. For training, the device is set t
```
python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
```

For inception_v3 you need to pass the following extra parameters:
```
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299
```
9 changes: 5 additions & 4 deletions references/classification/presets.py
@@ -9,21 +9,22 @@ def __init__(
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
auto_augment_policy=None,
random_erase_prob=0.0,
):
trans = [transforms.RandomResizedCrop(crop_size)]
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment())
trans.append(autoaugment.RandAugment(interpolation=interpolation))
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide())
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
Contributor Author (datumbox):
The change to the interpolation here is non-BC, but I consider the previous behaviour a bug rather than a feature: in the previous recipe there was a mismatch between the interpolation used for resizing and the one used by the AA methods.

trans.extend(
[
transforms.PILToTensor(),
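
As an aside on the interpolation fix discussed in the comment above, here is a minimal sketch of the pipeline the updated preset builds, assuming a torchvision version that ships `RandAugment`; the point is simply that the resized crop and the augmentation policy now share a single `InterpolationMode`:

```
# Minimal sketch of the fixed behaviour; assumes torchvision with RandAugment available.
import torch
from torchvision import transforms
from torchvision.transforms import autoaugment
from torchvision.transforms.functional import InterpolationMode

interpolation = InterpolationMode.BICUBIC  # e.g. what "--interpolation bicubic" maps to

train_transform = transforms.Compose(
    [
        # Before this change RandomResizedCrop used bilinear while the AA transform
        # fell back to its own default; now both receive the same mode.
        transforms.RandomResizedCrop(224, interpolation=interpolation),
        transforms.RandomHorizontalFlip(p=0.5),
        autoaugment.RandAugment(interpolation=interpolation),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
```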
120 changes: 74 additions & 46 deletions references/classification/train.py
@@ -14,22 +14,20 @@
from torchvision.transforms.functional import InterpolationMode


def train_one_epoch(
model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
):
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
Contributor (yiwen-song), Oct 21, 2021:
I'm not a huge fan of passing the whole args object to a single method (it's not clear which of its fields the function actually needs), but I can see you do this just to reduce the number of arguments. In the future we might want to add type hints for all the args used in this script, and also some documentation.

Contributor Author (datumbox):
Indeed, args is passed to reduce the number of parameters (merging 4 into 1). The same pattern is already used in other places of the script, for example:

def load_data(traindir, valdir, args):

Concerning type hints/documentation, I think you are right. For some reason most of the string args don't define their type. I've raised a new issue, #4694, to improve this.

model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

header = "Epoch: [{}]".format(epoch)
for image, target in metric_logger.log_every(data_loader, print_freq, header):
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
start_time = time.time()
image, target = image.to(device), target.to(device)
output = model(image)

optimizer.zero_grad()
if amp:
if args.amp:
with torch.cuda.amp.autocast():
loss = criterion(output, target)
scaler.scale(loss).backward()
@@ -40,16 +38,19 @@ def train_one_epoch(
loss.backward()
optimizer.step()

if model_ema and i % args.model_ema_steps == 0:
model_ema.update_parameters(model)
Contributor Author (datumbox):
Moving the EMA updates to a per-iteration level rather than per epoch.

if epoch < args.lr_warmup_epochs:
# Reset ema buffer to keep copying weights during warmup period
model_ema.n_averaged.fill_(0)
Contributor Author (datumbox):
Always copy the weights during warmup.


acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

if model_ema:
model_ema.update_parameters(model)


def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
model.eval()
@@ -106,24 +107,8 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
resize_size, crop_size = 256, 224
interpolation = InterpolationMode.BILINEAR
if args.model == "inception_v3":
resize_size, crop_size = 342, 299
elif args.model.startswith("efficientnet_"):
sizes = {
"b0": (256, 224),
"b1": (256, 240),
"b2": (288, 288),
"b3": (320, 300),
"b4": (384, 380),
"b5": (456, 456),
"b6": (528, 528),
"b7": (600, 600),
}
e_type = args.model.replace("efficientnet_", "")
resize_size, crop_size = sizes[e_type]
interpolation = InterpolationMode.BICUBIC
val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
interpolation = InterpolationMode(args.interpolation)
Contributor Author (datumbox):
Remove the hardcoding of resize/crop sizes based on model names. Instead, use parameters.


print("Loading training data")
st = time.time()
@@ -138,7 +123,10 @@ def load_data(traindir, valdir, args):
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
crop_size=train_crop_size,
interpolation=interpolation,
Contributor Author (datumbox):
Passing the right interpolation value fixes the discrepancy bug discussed above.

auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
),
)
if args.cache_dataset:
@@ -156,7 +144,9 @@ def load_data(traindir, valdir, args):
else:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation),
presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
),
)
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
@@ -224,26 +214,30 @@ def main(args):

criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

if args.norm_weight_decay is None:
parameters = model.parameters()
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
Contributor Author (datumbox):
Separate the BN/normalization params from the rest so that we can apply a different weight decay to them. Improves accuracy by 0.1-0.2 points.


opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(
model.parameters(),
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov="nesterov" in opt_name,
)
elif opt_name == "rmsprop":
optimizer = torch.optim.RMSprop(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
eps=0.0316,
alpha=0.9,
parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
Contributor Author (datumbox):
Adding AdamW, which is necessary for training ViT.

else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

scaler = torch.cuda.amp.GradScaler() if args.amp else None

@@ -288,13 +282,23 @@ def main(args):

model_ema = None
if args.model_ema:
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
# Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
# https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
#
# total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
# We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
# adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
alpha = 1.0 - args.model_ema_decay
alpha = min(1.0, alpha * adjust)
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
Contributor Author (datumbox):
Parameterize EMA independently from epochs.


if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
if not args.test_only:
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
Contributor Author (datumbox):
Quality-of-life improvement: avoid the very annoying error messages you get when you don't define all the optimizer params during validation.

args.start_epoch = checkpoint["epoch"] + 1
if model_ema:
model_ema.load_state_dict(checkpoint["model_ema"])
@@ -303,18 +307,18 @@
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
else:
evaluate(model, criterion, data_loader_test, device=device)
Contributor Author (datumbox):
Choose which model to validate depending on the flag provided.

return

print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(
model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
@@ -362,6 +366,12 @@ def get_args_parser(add_help=True):
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument(
"--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
)
@@ -415,15 +425,33 @@
parser.add_argument(
"--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
)
parser.add_argument(
"--model-ema-steps",
type=int,
default=32,
help="the number of iterations that controls how often to update the EMA model (default: 32)",
)
parser.add_argument(
"--model-ema-decay",
type=float,
default=0.9,
help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
default=0.99998,
help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
Contributor Author (datumbox):
Reconfiguring the default value of EMA decay now that we update per iteration instead of per epoch.

Contributor:
Noob question: is this default value of 0.99998 the one used most often?

Contributor Author (datumbox):
It's a good guess for ImageNet, considering the typical batch size for 8 GPUs. The reason for changing it so drastically is that we switch from one update per epoch to updates every X iterations (X=32, configurable).

)
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
parser.add_argument(
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
)
parser.add_argument(
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
)
parser.add_argument(
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
)
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)

return parser

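
To make the EMA decay adjustment introduced in main() above concrete, the short computation below evaluates it for an illustrative configuration; the numbers are examples picked for the walkthrough, not the settings behind any released weights.

```
# Mirrors the `adjust` computation added to main(); all values below are illustrative.
world_size = 8             # number of GPUs (e.g. --nproc_per_node=8)
batch_size = 32            # per-GPU batch size
model_ema_steps = 32       # --model-ema-steps default: update the EMA every 32 optimizer steps
epochs = 600               # example epoch budget
model_ema_decay = 0.99998  # --model-ema-decay default

adjust = world_size * batch_size * model_ema_steps / epochs  # ~13.65
alpha = min(1.0, (1.0 - model_ema_decay) * adjust)           # ~2.73e-4
effective_decay = 1.0 - alpha                                # ~0.99973, passed to ExponentialMovingAverage

print(f"adjust={adjust:.2f}, effective per-update decay={effective_decay:.5f}")
```

Because the adjustment scales with world size, batch size, EMA step interval and epoch count, the effective per-update decay stays roughly constant across training setups, which is the point of the change.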
13 changes: 13 additions & 0 deletions references/classification/train_quantization.py
@@ -236,6 +236,19 @@ def get_args_parser(add_help=True):
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")

parser.add_argument(
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
)
parser.add_argument(
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
)
parser.add_argument(
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
)
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)

return parser


3 changes: 3 additions & 0 deletions references/classification/utils.py
@@ -380,6 +380,9 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T

# Load the weights to the model to validate that everything works
# and remove unnecessary weights (such as auxiliaries, etc)
if checkpoint_key == "model_ema":
del checkpoint[checkpoint_key]["n_averaged"]
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
model.load_state_dict(checkpoint[checkpoint_key], strict=strict)

tmp_path = os.path.join(output_dir, str(model.__hash__()))
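
A usage sketch for the EMA branch added to store_model_weights above; it assumes the snippet is run from references/classification (where utils.py lives), and both the checkpoint path and the model choice are placeholders.

```
# Placeholder path and model; assumes the working directory is references/classification.
import torchvision
from utils import store_model_weights

# A checkpoint saved by train.py holds both a "model" and, when EMA is enabled, a "model_ema" entry.
model = torchvision.models.resnet50()
weights_path = store_model_weights(model, "./checkpoint.pth", checkpoint_key="model_ema")
# With checkpoint_key="model_ema", the helper strips the "n_averaged" buffer and the
# "module." prefix added by AveragedModel before loading and re-saving the clean weights.
print(weights_path)
```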
14 changes: 12 additions & 2 deletions test/test_ops.py
@@ -9,10 +9,10 @@
import torch
from common_utils import needs_cuda, cpu_and_gpu, assert_equal
from PIL import Image
from torch import Tensor
from torch import nn, Tensor
from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair
from torchvision import ops
from torchvision import models, ops


class RoIOpTester(ABC):
@@ -1176,5 +1176,15 @@ def test_stochastic_depth(self, mode, p):
assert p_value > 0.0001


class TestUtils:
@pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
def test_split_normalization_params(self, norm_layer):
model = models.mobilenet_v3_large(norm_layer=norm_layer)
params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])

assert len(params[0]) == 92
assert len(params[1]) == 82


if __name__ == "__main__":
pytest.main([__file__])