Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
AMC supports resnet (#2876)
Browse files Browse the repository at this point in the history
  • Loading branch information
chicm-ms authored Oct 10, 2020
1 parent 392e55f commit 6126960
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 219 deletions.
10 changes: 10 additions & 0 deletions docs/en_US/Compression/Pruner.md
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,16 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AMCPruner
```

### Reproduced Experiment

We implemented one of the experiments from [AMC: AutoML for Model Compression and Acceleration on Mobile Devices](https://arxiv.org/pdf/1802.03494.pdf): pruning **MobileNet** to 50% of its FLOPS on ImageNet, as described in the paper. Our experiment results are as follows:

| Model | Top 1 acc.(paper/ours) | Top 5 acc. (paper/ours) | FLOPS |
| ------------- | --------------| -------------- | ----- |
| MobileNet | 70.5% / 69.9% | 89.3% / 89.1% | 50% |

The experiment code can be found at [examples/model_compress](https://github.com/microsoft/nni/tree/master/examples/model_compress/amc/).

## ADMM Pruner
Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique,
by decomposing the original nonconvex problem into two subproblems that can be solved iteratively. In weight pruning problem, these two subproblems are solved via 1) gradient descent algorithm and 2) Euclidean projection respectively.
Expand Down
40 changes: 21 additions & 19 deletions examples/model_compress/amc/amc_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch
import torch.nn as nn

from torchvision.models import resnet
from nni.compression.torch import AMCPruner
from data import get_split_dataset
from utils import AverageMeter, accuracy
Expand All @@ -16,7 +16,8 @@

def parse_args():
parser = argparse.ArgumentParser(description='AMC search script')
parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2'], help='model to prune')
parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2', 'resnet18', 'resnet34', 'resnet50'],
help='model to prune')
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path')
Expand All @@ -28,27 +29,29 @@ def parse_args():
parser.add_argument('--train_episode', default=800, type=int, help='number of training episode')
parser.add_argument('--n_gpu', default=1, type=int, help='number of gpu to use')
parser.add_argument('--n_worker', default=16, type=int, help='number of data loader worker')
parser.add_argument('--job', default='train_export', type=str, choices=['train_export', 'export_only'],
help='search best pruning policy and export or just export model with searched policy')
parser.add_argument('--export_path', default=None, type=str, help='path for exporting models')
parser.add_argument('--searched_model_path', default=None, type=str, help='path for searched best wrapped model')
parser.add_argument('--suffix', default=None, type=str, help='suffix of auto-generated log directory')

return parser.parse_args()


def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
if model == 'mobilenet' and dataset == 'imagenet':
from mobilenet import MobileNet
net = MobileNet(n_class=1000)
elif model == 'mobilenetv2' and dataset == 'imagenet':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=1000)
elif model == 'mobilenet' and dataset == 'cifar10':
if dataset == 'imagenet':
n_class = 1000
elif dataset == 'cifar10':
n_class = 10
else:
raise ValueError('unsupported dataset')

if model == 'mobilenet':
from mobilenet import MobileNet
net = MobileNet(n_class=10)
elif model == 'mobilenetv2' and dataset == 'cifar10':
net = MobileNet(n_class=n_class)
elif model == 'mobilenetv2':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=10)
net = MobileNetV2(n_class=n_class)
elif model.startswith('resnet'):
net = resnet.__dict__[model](pretrained=True)
in_features = net.fc.in_features
net.fc = nn.Linear(in_features, n_class)
else:
raise NotImplementedError
if checkpoint_path:
Expand Down Expand Up @@ -130,7 +133,6 @@ def validate(val_loader, model, verbose=False):
}]
pruner = AMCPruner(
model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset,
train_episode=args.train_episode, job=args.job, export_path=args.export_path,
searched_model_path=args.searched_model_path,
flops_ratio=args.flops_ratio, lbound=args.lbound, rbound=args.rbound)
train_episode=args.train_episode, flops_ratio=args.flops_ratio, lbound=args.lbound,
rbound=args.rbound, suffix=args.suffix)
pruner.compress()
34 changes: 19 additions & 15 deletions examples/model_compress/amc/amc_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from nni.compression.torch.pruning.amc.lib.net_measure import measure_model
from nni.compression.torch.pruning.amc.lib.utils import get_output_folder
from nni.compression.torch import ModelSpeedup

from data import get_dataset
from utils import AverageMeter, accuracy, progress_bar
Expand All @@ -28,17 +29,19 @@ def parse_args():
parser = argparse.ArgumentParser(description='AMC train / fine-tune script')
parser.add_argument('--model_type', default='mobilenet', type=str, help='name of the model to train')
parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=1, type=int, help='number of GPUs to use')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--n_worker', default=4, type=int, help='number of data loader worker')
parser.add_argument('--lr_type', default='exp', type=str, help='lr scheduler (exp/cos/step3/fixed)')
parser.add_argument('--n_epoch', default=50, type=int, help='number of epochs to train')
parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use')
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--n_worker', default=32, type=int, help='number of data loader worker')
parser.add_argument('--lr_type', default='cos', type=str, help='lr scheduler (exp/cos/step3/fixed)')
parser.add_argument('--n_epoch', default=150, type=int, help='number of epochs to train')
parser.add_argument('--wd', default=4e-5, type=float, help='weight decay')
parser.add_argument('--seed', default=None, type=int, help='random seed to set')
parser.add_argument('--data_root', default='./data', type=str, help='dataset path')
# resume
parser.add_argument('--ckpt_path', default=None, type=str, help='checkpoint path to fine tune')
parser.add_argument('--mask_path', default=None, type=str, help='mask path for speedup')

# run eval
parser.add_argument('--eval', action='store_true', help='Simply run eval')
parser.add_argument('--calc_flops', action='store_true', help='Calculate flops')
Expand All @@ -56,20 +59,21 @@ def get_model(args):
raise NotImplementedError

if args.model_type == 'mobilenet':
net = MobileNet(n_class=n_class).cuda()
net = MobileNet(n_class=n_class)
elif args.model_type == 'mobilenetv2':
net = MobileNetV2(n_class=n_class).cuda()
net = MobileNetV2(n_class=n_class)
else:
raise NotImplementedError

if args.ckpt_path is not None:
# the checkpoint can be a saved whole model object exported by amc_search.py, or a state_dict
# the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py
print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
ckpt = torch.load(args.ckpt_path)
if type(ckpt) == dict:
net.load_state_dict(ckpt['state_dict'])
else:
net = ckpt
net.load_state_dict(torch.load(args.ckpt_path))
if args.mask_path is not None:
SZ = 224 if args.dataset == 'imagenet' else 32
data = torch.randn(2, 3, SZ, SZ)
ms = ModelSpeedup(net, data, args.mask_path)
ms.speedup_model()

net.to(args.device)
if torch.cuda.is_available() and args.n_gpu > 1:
Expand Down Expand Up @@ -204,7 +208,7 @@ def save_checkpoint(state, is_best, checkpoint_dir='.'):

if args.calc_flops:
IMAGE_SIZE = 224 if args.dataset == 'imagenet' else 32
n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE)
n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE, args.device)
print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M'.format(n_params / 1e6, n_flops / 1e6))
exit(0)

Expand Down
114 changes: 41 additions & 73 deletions src/sdk/pynni/nni/compression/torch/pruning/amc/amc_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT license.

import os
import logging
from copy import deepcopy
from argparse import Namespace
import numpy as np
Expand All @@ -15,6 +16,8 @@

torch.backends.cudnn.deterministic = True

_logger = logging.getLogger(__name__)

class AMCPruner(Pruner):
"""
A pytorch implementation of AMC: AutoML for Model Compression and Acceleration on Mobile Devices.
Expand All @@ -38,13 +41,6 @@ class AMCPruner(Pruner):
Data loader of validation dataset.
suffix: str
suffix to help you remember what experiment you ran. Default: None.
job: str
train_export: search best pruned model and export after search.
export_only: export a searched model, searched_model_path and export_path must be specified.
searched_model_path: str
when job == export_only, use searched_model_path to specify the path of the searched model.
export_path: str
path for exporting models
# parameters for pruning environment
model_type: str
Expand Down Expand Up @@ -118,9 +114,6 @@ def __init__(
evaluator,
val_loader,
suffix=None,
job='train_export',
export_path=None,
searched_model_path=None,
model_type='mobilenet',
dataset='cifar10',
flops_ratio=0.5,
Expand Down Expand Up @@ -149,9 +142,8 @@ def __init__(
epsilon=50000,
seed=None):

self.job = job
self.searched_model_path = searched_model_path
self.export_path = export_path
self.val_loader = val_loader
self.evaluator = evaluator

if seed is not None:
np.random.seed(seed)
Expand All @@ -165,11 +157,9 @@ def __init__(
# build folder and logs
base_folder_name = '{}_{}_r{}_search'.format(model_type, dataset, flops_ratio)
if suffix is not None:
base_folder_name = base_folder_name + '_' + suffix
self.output_dir = get_output_folder(output_dir, base_folder_name)

if self.export_path is None:
self.export_path = os.path.join(self.output_dir, '{}_r{}_exported.pth'.format(model_type, flops_ratio))
self.output_dir = os.path.join(output_dir, base_folder_name + '-' + suffix)
else:
self.output_dir = get_output_folder(output_dir, base_folder_name)

self.env_args = Namespace(
model_type=model_type,
Expand All @@ -182,47 +172,42 @@ def __init__(
channel_round=channel_round,
output=self.output_dir
)

self.env = ChannelPruningEnv(
self, evaluator, val_loader, checkpoint, args=self.env_args)

if self.job == 'train_export':
print('=> Saving logs to {}'.format(self.output_dir))
self.tfwriter = SummaryWriter(log_dir=self.output_dir)
self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
print('=> Output path: {}...'.format(self.output_dir))

nb_states = self.env.layer_embedding.shape[1]
nb_actions = 1 # just 1 action here

rmsize = rmsize * len(self.env.prunable_idx) # for each layer
print('** Actual replay buffer size: {}'.format(rmsize))

self.ddpg_args = Namespace(
hidden1=hidden1,
hidden2=hidden2,
lr_c=lr_c,
lr_a=lr_a,
warmup=warmup,
discount=discount,
bsize=bsize,
rmsize=rmsize,
window_length=window_length,
tau=tau,
init_delta=init_delta,
delta_decay=delta_decay,
max_episode_length=max_episode_length,
debug=debug,
train_episode=train_episode,
epsilon=epsilon
)
self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)
_logger.info('=> Saving logs to %s', self.output_dir)
self.tfwriter = SummaryWriter(log_dir=self.output_dir)
self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
_logger.info('=> Output path: %s...', self.output_dir)

nb_states = self.env.layer_embedding.shape[1]
nb_actions = 1 # just 1 action here

rmsize = rmsize * len(self.env.prunable_idx) # for each layer
_logger.info('** Actual replay buffer size: %d', rmsize)

self.ddpg_args = Namespace(
hidden1=hidden1,
hidden2=hidden2,
lr_c=lr_c,
lr_a=lr_a,
warmup=warmup,
discount=discount,
bsize=bsize,
rmsize=rmsize,
window_length=window_length,
tau=tau,
init_delta=init_delta,
delta_decay=delta_decay,
max_episode_length=max_episode_length,
debug=debug,
train_episode=train_episode,
epsilon=epsilon
)
self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)


def compress(self):
if self.job == 'train_export':
self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)
self.export_pruned_model()
self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)

def train(self, num_episode, agent, env, output_dir):
agent.is_training = True
Expand Down Expand Up @@ -263,12 +248,11 @@ def train(self, num_episode, agent, env, output_dir):
observation = deepcopy(observation2)

if done: # end of episode
print(
'#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}'.format(
_logger.info(
'#%d: episode_reward: %.4f acc: %.4f, ratio: %.4f',
episode, episode_reward,
info['accuracy'],
info['compress_ratio']
)
)
self.text_writer.write(
'#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}\n'.format(
Expand Down Expand Up @@ -310,19 +294,3 @@ def train(self, num_episode, agent, env, output_dir):
self.text_writer.write('best reward: {}\n'.format(env.best_reward))
self.text_writer.write('best policy: {}\n'.format(env.best_strategy))
self.text_writer.close()

def export_pruned_model(self):
if self.searched_model_path is None:
wrapper_model_ckpt = os.path.join(self.output_dir, 'best_wrapped_model.pth')
else:
wrapper_model_ckpt = self.searched_model_path
self.env.reset()
self.bound_model.load_state_dict(torch.load(wrapper_model_ckpt))

print('validate searched model:', self.env._validate(self.env._val_loader, self.env.model))
self.env.export_model()
self._unwrap_model()
print('validate exported model:', self.env._validate(self.env._val_loader, self.env.model))

torch.save(self.bound_model, self.export_path)
print('exported model saved to: {}'.format(self.export_path))
Loading

0 comments on commit 6126960

Please sign in to comment.