This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix lint, Add tests, fix bugs, add examples
- Loading branch information
1 parent
be5d0dd
commit 3e8ca54
Showing
7 changed files
with
458 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import os | ||
import logging | ||
import argparse | ||
import mxnet as mx | ||
from common import modelzoo | ||
import gluoncv | ||
from gluoncv.model_zoo import get_model | ||
from mxnet.contrib.amp import amp | ||
import numpy as np | ||
|
||
def download_model(model_name, logger=None): | ||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
model_path = os.path.join(dir_path, 'model') | ||
if logger is not None: | ||
logger.info('Downloading model {}... into path {}'.format(model_name, model_path)) | ||
return modelzoo.download_model(args.model, os.path.join(dir_path, 'model')) | ||
|
||
|
||
def save_symbol(fname, sym, logger=None): | ||
if logger is not None: | ||
logger.info('Saving symbol into file at {}'.format(fname)) | ||
sym.save(fname, remove_amp_cast=False) | ||
|
||
|
||
def save_params(fname, arg_params, aux_params, logger=None): | ||
if logger is not None: | ||
logger.info('Saving params into file at {}'.format(fname)) | ||
save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()} | ||
save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()}) | ||
mx.nd.save(fname, save_dict) | ||
|
||
|
||
if __name__ == '__main__': | ||
symbolic_models = ['imagenet1k-resnet-152', | ||
'imagenet1k-resnet-18', | ||
'imagenet1k-resnet-34', | ||
'imagenet1k-resnet-50', | ||
'imagenet1k-resnet-101', | ||
'imagenet1k-resnext-50', | ||
'imagenet1k-resnext-101', | ||
'imagenet1k-resnext-101-64x4d', | ||
'imagenet11k-place365ch-resnet-152', | ||
'imagenet11k-place365ch-resnet-50'] | ||
gluon_models = ['resnet18_v1', | ||
'resnet50_v1', | ||
'resnet101_v1', | ||
'squeezenet1.0', | ||
'mobilenet1.0', | ||
'mobilenetv2_1.0', | ||
'inceptionv3'] | ||
models = symbolic_models + gluon_models | ||
|
||
parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model') | ||
parser.add_argument('--model', type=str, choices=models) | ||
parser.add_argument('--run-dummy-inference', action='store_true', default=False, | ||
help='Will generate random input of shape (1, 3, 224, 224) ' | ||
'and run a dummy inference forward pass') | ||
parser.add_argument('--use-gluon-model', action='store_true', default=False, | ||
help='If enabled, will download pretrained model from Gluon-CV ' | ||
'and convert to mixed precision model ') | ||
args = parser.parse_args() | ||
logging.basicConfig() | ||
logger = logging.getLogger('logger') | ||
logger.setLevel(logging.INFO) | ||
|
||
if not args.use_gluon_model: | ||
assert args.model in symbolic_models, "Please choose one of the available symbolic models: {} \ | ||
If you want to use gluon use the script with --use-gluon-model".format(symbolic_models) | ||
|
||
prefix, epoch = download_model(model_name=args.model, logger=logger) | ||
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) | ||
result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, arg_params, aux_params) | ||
sym_name = "%s-amp-symbol.json" % (prefix) | ||
save_symbol(sym_name, result_sym, logger) | ||
param_name = '%s-%04d.params' % (prefix + '-amp', epoch) | ||
save_params(param_name, result_arg_params, result_aux_params, logger) | ||
if args.run_dummy_inference: | ||
logger.info("Running inference on the mixed precision model with dummy input, batch size: 1") | ||
mod = mx.mod.Module(result_sym, data_names=['data'], label_names=['softmax_label'], context=mx.gpu(0)) | ||
mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]]) | ||
mod.set_params(arg_params, aux_params) | ||
mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))], | ||
label=[mx.nd.ones((1,))])) | ||
result = mod.get_outputs()[0].asnumpy() | ||
logger.info("Inference run successfully") | ||
else: | ||
assert args.model in gluon_models, "Please choose one of the available gluon models: {} \ | ||
If you want to use symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models) | ||
net = gluoncv.model_zoo.get_model(args.model, pretrained=True) | ||
net.hybridize() | ||
result_before1 = net.forward(mx.nd.zeros((1, 3, 224, 224))) | ||
net = amp.convert_hybrid_block(net) | ||
net.export("{}-amp".format(args.model), remove_amp_cast=False) | ||
if args.run_dummy_inference: | ||
logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1") | ||
result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0))) | ||
result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0))) | ||
logger.info("Inference run successfully") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../image-classification/common |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.