This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix lint, Add tests, fix bugs, add examples
anirudh2290 committed Jun 4, 2019
1 parent be5d0dd commit 3e8ca54
Showing 7 changed files with 458 additions and 43 deletions.
98 changes: 98 additions & 0 deletions example/automatic-mixed-precision/amp_model_conversion.py
@@ -0,0 +1,98 @@
import os
import logging
import argparse
import mxnet as mx
from common import modelzoo
import gluoncv
from gluoncv.model_zoo import get_model
from mxnet.contrib.amp import amp
import numpy as np

def download_model(model_name, logger=None):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if logger is not None:
        logger.info('Downloading model {}... into path {}'.format(model_name, model_path))
    return modelzoo.download_model(model_name, model_path)


def save_symbol(fname, sym, logger=None):
    if logger is not None:
        logger.info('Saving symbol into file at {}'.format(fname))
    sym.save(fname, remove_amp_cast=False)


def save_params(fname, arg_params, aux_params, logger=None):
    if logger is not None:
        logger.info('Saving params into file at {}'.format(fname))
    save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
    save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
    mx.nd.save(fname, save_dict)


if __name__ == '__main__':
    symbolic_models = ['imagenet1k-resnet-152',
                       'imagenet1k-resnet-18',
                       'imagenet1k-resnet-34',
                       'imagenet1k-resnet-50',
                       'imagenet1k-resnet-101',
                       'imagenet1k-resnext-50',
                       'imagenet1k-resnext-101',
                       'imagenet1k-resnext-101-64x4d',
                       'imagenet11k-place365ch-resnet-152',
                       'imagenet11k-place365ch-resnet-50']
    gluon_models = ['resnet18_v1',
                    'resnet50_v1',
                    'resnet101_v1',
                    'squeezenet1.0',
                    'mobilenet1.0',
                    'mobilenetv2_1.0',
                    'inceptionv3']
    models = symbolic_models + gluon_models

    parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model')
    parser.add_argument('--model', type=str, choices=models)
    parser.add_argument('--run-dummy-inference', action='store_true', default=False,
                        help='Will generate random input of shape (1, 3, 224, 224) '
                             'and run a dummy inference forward pass')
    parser.add_argument('--use-gluon-model', action='store_true', default=False,
                        help='If enabled, will download pretrained model from Gluon-CV '
                             'and convert to mixed precision model')
    args = parser.parse_args()
    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    if not args.use_gluon_model:
        assert args.model in symbolic_models, \
            "Please choose one of the available symbolic models: {} "  \
            "If you want to use a Gluon model instead, run the script with --use-gluon-model".format(symbolic_models)

        prefix, epoch = download_model(model_name=args.model, logger=logger)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, arg_params, aux_params)
        sym_name = "%s-amp-symbol.json" % (prefix)
        save_symbol(sym_name, result_sym, logger)
        param_name = '%s-%04d.params' % (prefix + '-amp', epoch)
        save_params(param_name, result_arg_params, result_aux_params, logger)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy input, batch size: 1")
            mod = mx.mod.Module(result_sym, data_names=['data'], label_names=['softmax_label'], context=mx.gpu(0))
            mod.bind(data_shapes=[('data', (1, 3, 224, 224))], label_shapes=[('softmax_label', (1,))])
            # Use the converted params with the converted symbol
            mod.set_params(result_arg_params, result_aux_params)
            mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                        label=[mx.nd.ones((1,))]))
            result = mod.get_outputs()[0].asnumpy()
            logger.info("Inference run successfully")
    else:
        assert args.model in gluon_models, \
            "Please choose one of the available gluon models: {} " \
            "If you want to use a symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models)
        net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
        net.hybridize()
        # Run one forward pass so the hybridized graph is cached before conversion
        result_before = net.forward(mx.nd.zeros((1, 3, 224, 224)))
        net = amp.convert_hybrid_block(net)
        net.export("{}-amp".format(args.model), remove_amp_cast=False)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
            result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
            logger.info("Inference run successfully")
1 change: 1 addition & 0 deletions example/automatic-mixed-precision/common
142 changes: 102 additions & 40 deletions python/mxnet/contrib/amp/amp.py
@@ -31,7 +31,7 @@
from ...symbol import Symbol
from ...symbol import contrib as symbol_contrib
from ... import ndarray
from ...ndarray import NDArray, _DTYPE_NP_TO_MX
from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from . import lists
from ...gluon import trainer
from ... import base
@@ -392,25 +392,12 @@ def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
    else:
        fp32_ops = lists.symbol.FP32_FUNCS

    common_ops = set(target_dtype_ops) & set(fp32_ops)
    assert len(common_ops) == 0, "Ops cannot be in both FP16 list and FP32 list {}".format(common_ops)

    combined_ops = set(target_dtype_ops + fp32_ops)
    all_fp16_fp32_ops = set(lists.symbol.FP16_FUNCS + lists.symbol.FP32_FUNCS + lists.symbol.FP16_FP32_FUNCS)

    assert combined_ops.issubset(all_fp16_fp32_ops), "Can only choose ops from one of the three lists " \
                                                     "for fp16_ops and fp32_ops" \
                                                     " 1. amp.list_fp16_ops()" \
                                                     " 2. amp.list_fp32_ops()" \
                                                     " 3. amp.list_fp16_fp32_ops()"

    widest_dtype_ops = lists.symbol.WIDEST_TYPE_CASTS

    if conditional_fp32_ops is not None:
        assert isinstance(conditional_fp32_ops, list) << "conditional_fp32_ops should be a list"
        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops should be a list"
    else:
        conditional_fp32_ops = lists.symbol.CONDITIONAL_FP32_FUNCS

    original_conditional_op_names = []
    conditional_op_names = []
    param_names = []
    param_vals = []
@@ -429,6 +416,36 @@ def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
    else:
        excluded_sym_names = []

    for original_conditional_fp32_op in lists.symbol.CONDITIONAL_FP32_FUNCS:
        original_conditional_op_names.append(original_conditional_fp32_op[0])

    common_ops = set(target_dtype_ops) & set(fp32_ops)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in target_dtype_ops and fp32_ops {}".format(common_ops)
    common_ops = set(target_dtype_ops) & set(conditional_op_names)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in target_dtype_ops and conditional_fp32_ops {}".format(common_ops)
    common_ops = set(conditional_op_names) & set(fp32_ops)
    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
                                 "Common ops in fp32_ops and conditional_fp32_ops {}".format(common_ops)

    combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
    all_fp16_fp32_ops = set(lists.symbol.FP16_FUNCS + lists.symbol.FP32_FUNCS
                            + lists.symbol.FP16_FP32_FUNCS + original_conditional_op_names)

    illegal_ops = combined_ops - all_fp16_fp32_ops
    assert not illegal_ops, '''Can only choose ops from one of the four lists
                            for fp16_ops and fp32_ops
                            1. amp.list_fp16_ops()
                            2. amp.list_fp32_ops()
                            3. amp.list_fp16_fp32_ops()
                            4. amp.list_conditional_fp32_ops()
                            Op %s not in any of them''' % (illegal_ops)

    widest_dtype_ops = lists.symbol.WIDEST_TYPE_CASTS

    target_dtype = _DTYPE_NP_TO_MX[np.dtype(target_dtype).type]

    attr_dict = sym.attr_dict()
@@ -478,8 +495,8 @@ def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dt
                  fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=None):
    """API for converting a model from FP32 model to a mixed precision model.
    MXNet tries to convert the FP32 model to mixed precision model by adding
    cast layers using amp_cast and amp_multicast operators. The decision on
    cast layers using amp_cast and amp_multicast operators which can be used for inference use cases.
    The decision on which cast layer to add is based on hardcoded lists for Automatic Mixed Precision
    in MXNet. These lists can be overridden by the user by providing their own lists
    using : target_precision_ops, fp32_ops, widest_precision_ops, conditional_fp32_ops
@@ -524,13 +541,26 @@ def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dt
    sym = convert_symbol(sym, target_dtype, target_dtype_ops,
                         fp32_ops, conditional_fp32_ops,
                         excluded_sym_names, data_names)
    attr_dict = sym.attr_dict()
    for sym_name in sym.list_arguments():
        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
            if attr_dict[sym_name]["__dtype__"] != "-1":
                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
                arg_params[sym_name] = arg_params[sym_name].astype(typ)

    for sym_name in sym.list_auxiliary_states():
        if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]:
            if attr_dict[sym_name]["__dtype__"] != "-1":
                typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])]
                aux_params[sym_name] = aux_params[sym_name].astype(typ)

    return sym, arg_params, aux_params

def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
                         fp32_ops=None, conditional_fp32_ops=None,
                         excluded_sym_names=None, ctx=gpu(0)):
    """Given a hybrid block/symbol block representing a FP32 model and a target_dtype,
    return a block with mixed precision support which can be used for inference.
    return a block with mixed precision support which can be used for inference use cases.
    Parameters
    ----------
@@ -550,29 +580,56 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being quantized
    ctx : Context
        Context on which model parameters should live
    """
    from ...gluon import HybridBlock, SymbolBlock
    if isinstance(block, HybridBlock):
        inputs, sym = block._cached_graph
        converted_sym = convert_symbol(sym, target_dtype, target_dtype_ops,
                                       fp32_ops, conditional_fp32_ops,
                                       excluded_sym_names)

        arg_names = set(converted_sym.list_arguments())
        aux_names = set(converted_sym.list_auxiliary_states())
        arg_dict = {}
        # collect params
        for name, param in block.collect_params().items():
            if name in arg_names:
                arg_dict['arg:%s'%name] = param._reduce()
            else:
                assert name in aux_names
                arg_dict['aux:%s'%name] = param._reduce()

        ret = SymbolBlock(converted_sym, inputs)

        ret.collect_params().load_dict(arg_dict, ctx=ctx)
        return ret
    assert isinstance(block, HybridBlock), "block input should be a HybridBlock"
    if not block._cached_graph:
        raise RuntimeError(
            "Please first call block.hybridize() and then run forward with "
            "this block at least once before calling convert_hybrid_block.")

    inputs, sym = block._cached_graph
    input_names = []
    for inp in inputs:
        input_names.append(inp.name)
    converted_sym = convert_symbol(sym, target_dtype, target_dtype_ops,
                                   fp32_ops, conditional_fp32_ops,
                                   excluded_sym_names, data_names=input_names)

    arg_names = set(converted_sym.list_arguments())
    aux_names = set(converted_sym.list_auxiliary_states())
    arg_dict = {}
    attr_dict = converted_sym.attr_dict()
    # collect params
    for name, param in block.collect_params().items():
        if name in arg_names:
            arg_dict['arg:%s'%name] = param._reduce()
            if name in attr_dict and "__dtype__" in attr_dict[name]:
                if attr_dict[name]["__dtype__"] != "-1":
                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
                    arg_dict['arg:%s'%name] = arg_dict['arg:%s'%name].astype(typ)
        else:
            assert name in aux_names
            arg_dict['aux:%s'%name] = param._reduce()
            if name in attr_dict and "__dtype__" in attr_dict[name]:
                if attr_dict[name]["__dtype__"] != "-1":
                    typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])]
                    arg_dict['aux:%s'%name] = arg_dict['aux:%s'%name].astype(typ)

    ret = SymbolBlock(converted_sym, inputs)
    for key, param in ret.collect_params().items():
        arg_param_name = "arg:%s" % key
        if arg_param_name in arg_dict and param.dtype != arg_dict[arg_param_name].dtype:
            param.cast(arg_dict[arg_param_name].dtype)

        aux_param_name = "aux:%s" % key
        if aux_param_name in arg_dict and param.dtype != arg_dict[aux_param_name].dtype:
            param.cast(arg_dict[aux_param_name].dtype)

    ret.collect_params().load_dict(arg_dict, ctx=ctx)
    return ret

def list_fp16_ops():
    """Get the default list of FP16 ops for AMP
@@ -588,3 +645,8 @@ def list_fp16_fp32_ops():
    """Get the default list of ops which run in both FP16 and FP32
    """
    return lists.symbol.FP16_FP32_FUNCS

def list_conditional_fp32_ops():
    """Get the conditional fp32 ops list
    """
    return lists.symbol.CONDITIONAL_FP32_FUNCS
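With this change, convert_hybrid_block requires a cached graph, so a Gluon block must be hybridized and run forward once before conversion. A minimal sketch of that flow under those assumptions (the resnet18_v1 model, input shape, and GPU context below are illustrative, not part of the commit):

```python
import mxnet as mx
from mxnet.contrib.amp import amp
from gluoncv.model_zoo import get_model

net = get_model('resnet18_v1', pretrained=True)
net.hybridize()
# Run one forward pass so the hybridized graph is cached before conversion;
# otherwise convert_hybrid_block raises the RuntimeError added above.
net(mx.nd.zeros((1, 3, 224, 224)))

amp_net = amp.convert_hybrid_block(net, ctx=mx.gpu(0))
out = amp_net(mx.nd.zeros((1, 3, 224, 224), ctx=mx.gpu(0)))
print(out.dtype, out.shape)
```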
2 changes: 1 addition & 1 deletion python/mxnet/gluon/parameter.py
@@ -940,7 +940,7 @@ def load_dict(self, param_dict, ctx=None, allow_missing=False,
        loaded = [(k[4:] if k.startswith('arg:') or k.startswith('aux:') else k, v) \
            for k, v in param_dict.items()] if isinstance(param_dict, dict) else param_dict
        arg_dict = {restore_prefix+k: v for k, v in loaded}
        error_str = "file: %" % (filename) if filename else "param_dict"
        error_str = "file: %s" % (filename) if filename else "param_dict"
        if not allow_missing:
            for name in self.keys():
                assert name in arg_dict, \
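The load_dict fix above sits on the path used by both save_params() in the new example and convert_hybrid_block(), which pass it dictionaries keyed with the 'arg:'/'aux:' prefixes that mx.nd.save checkpoints use. A small sketch of that behaviour with a toy layer (the Dense block is a stand-in chosen for illustration, not taken from this commit):

```python
import mxnet as mx
from mxnet import gluon

net = gluon.nn.Dense(2, in_units=4)
net.initialize()

# Build a dict with 'arg:'-prefixed keys; load_dict strips the prefix before matching names.
prefixed = {'arg:' + name: param.data() for name, param in net.collect_params().items()}
net.collect_params().load_dict(prefixed, ctx=mx.cpu())
```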