
Conversion from FP32 model to Mixed Precision model #15118

Merged: 45 commits, Jun 28, 2019

Commits
36e5579
Initial AMP commit
anirudh2290 May 15, 2019
70409d0
Fix
anirudh2290 May 16, 2019
ae3734f
Merge AMP Changes
anirudh2290 May 25, 2019
9f041cc
AMP Changes to support conditional op names switch
anirudh2290 May 30, 2019
4dce69e
Add example and fix issues with AMP conversion
anirudh2290 Jun 1, 2019
8d63335
Remove amp convert symbol test
anirudh2290 Jun 1, 2019
e526c16
Fix comment for inference use case
anirudh2290 Jun 1, 2019
888daa7
Remove input_names for convert_hybrid_block
anirudh2290 Jun 1, 2019
ea8b220
Check all conditions
anirudh2290 Jun 1, 2019
eded365
Fix lint
anirudh2290 Jun 1, 2019
be5d0dd
Fix error_str for load_dict
anirudh2290 Jun 1, 2019
3e8ca54
Fix lint, Add tests, fix bugs, add examples
anirudh2290 Jun 4, 2019
7640f50
Fix warnings
anirudh2290 Jun 4, 2019
42967e8
Add license for example script
anirudh2290 Jun 4, 2019
f502d74
Remove gpu test and move tests to test_contrib_amp
anirudh2290 Jun 4, 2019
f7d051d
Clean up AMP tests
anirudh2290 Jun 4, 2019
57060e7
Add additional comments, add tutorial
anirudh2290 Jun 5, 2019
5a4b1f7
Move the test to gpu dir
anirudh2290 Jun 5, 2019
7e1feae
Make the code python3 compatible
anirudh2290 Jun 5, 2019
ea7dd32
Upgrade archive utility, fixes: #15084
anirudh2290 Jun 7, 2019
94156b6
Allow AR path to be chosen by user
anirudh2290 Jun 7, 2019
9ef7bd3
Use current_context in tutorial
anirudh2290 Jun 10, 2019
b5173b9
Update __all__
anirudh2290 Jun 10, 2019
f9d09a4
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Jun 10, 2019
9297970
Merge with load params API changes
anirudh2290 Jun 10, 2019
eb186d0
Revert "Allow AR path to be chosen by user"
anirudh2290 Jun 10, 2019
80dd7bc
Revert "Upgrade archive utility, fixes: #15084"
anirudh2290 Jun 10, 2019
1ea508f
Set numpy dtype to float32
anirudh2290 Jun 11, 2019
8e52789
Address review comments
anirudh2290 Jun 17, 2019
61e942f
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Jun 17, 2019
bba14e0
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Jun 20, 2019
9c72372
Add range based for
anirudh2290 Jun 20, 2019
89ea0cc
Change quantized to low precision
anirudh2290 Jun 20, 2019
ed1b814
Fix lint
anirudh2290 Jun 20, 2019
65ebc74
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
anirudh2290 Jun 20, 2019
43bc0c9
Fix pylint
anirudh2290 Jun 20, 2019
b70ab2f
Forward args for Node::Create
anirudh2290 Jun 20, 2019
ed82db1
Fixes
anirudh2290 Jun 21, 2019
1200dde
Add dtype casting wherever needed
anirudh2290 Jun 27, 2019
85a50e2
Fix lint in source
anirudh2290 Jun 27, 2019
22d3a76
Add cast_optional_params to example
anirudh2290 Jun 27, 2019
383d664
Tweak example
anirudh2290 Jun 27, 2019
2480273
Add README
anirudh2290 Jun 27, 2019
8df637d
Add README
anirudh2290 Jun 27, 2019
9903222
Add cast_optional_params test for convert_model and convert_hybrid_bloc
anirudh2290 Jun 28, 2019
42 changes: 42 additions & 0 deletions docs/tutorials/amp/amp_tutorial.md
@@ -31,12 +31,14 @@ For demonstration purposes we will use a synthetic data loader.


```python
import os
import logging
import warnings
import time
import mxnet as mx
import mxnet.gluon as gluon
from mxnet import autograd
from mxnet.test_utils import download_model
import gluoncv as gcv
from gluoncv.model_zoo import get_model

@@ -249,6 +251,46 @@ for epoch in range(1):

We got a 60% speed increase from 3 additional lines of code!

## Inference with AMP

To run inference in mixed precision with a model trained in FP32, use the conversion APIs: `amp.convert_model` for symbolic models and `amp.convert_hybrid_block` for Gluon models. Each takes the FP32 model as input and returns a mixed precision model that can be used to run inference. Below, for both a Gluon model and a symbolic model, we demonstrate: 1. conversion from an FP32 model to a mixed precision model; 2. inference on the mixed precision model.

```python
with mx.Context(mx.gpu(0)):
    # Assumes `amp` and `np` are imported earlier in the tutorial
    # (from mxnet.contrib import amp; import numpy as np)

    # Below is an example of converting a gluon hybrid block to a mixed precision block
    model = get_model("resnet50_v1")
    model.collect_params().initialize(ctx=mx.current_context())
    model.hybridize()
    model(mx.nd.zeros((1, 3, 224, 224)))
    converted_model = amp.convert_hybrid_block(model)

    # Run dummy inference with the converted gluon model
    result = converted_model.forward(mx.nd.random.uniform(shape=(1, 3, 224, 224),
                                                          dtype=np.float32))

    # Below is an example of converting a symbolic model to a mixed precision model
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)
    prefix, epoch = mx.test_utils.download_model("imagenet1k-resnet-18", dst_dir=model_path)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    result_sym, result_arg_params, result_aux_params = amp.convert_model(sym,
                                                                         arg_params,
                                                                         aux_params)

    # Run dummy inference with the converted symbolic model
    mod = mx.mod.Module(result_sym, data_names=["data"], label_names=["softmax_label"],
                        context=mx.current_context())
    mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]])
    mod.set_params(result_arg_params, result_aux_params)
    mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                label=[mx.nd.ones((1,))]))
    mod.get_outputs()[0].wait_to_read()
    print("Conversion and Inference completed successfully")
```
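
If you want to keep the converted Gluon model, it can be exported with the `amp_cast` nodes preserved in the graph, as the example script in this PR does. A minimal sketch (the file prefix is just an example name):

```python
# Export the converted block, keeping amp_cast operators in the saved symbol
converted_model.export("resnet50_v1-amp", remove_amp_cast=False)
```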



## Current limitations of AMP

- AMP's dynamic loss scaling currently supports only the Gluon trainer with the `update_on_kvstore=False` option set (see the sketch below)
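
For reference, a minimal sketch of the supported setup, assuming `amp.init()` has been called and `net`, `loss`, `data`, `label`, and `batch_size` are defined as in the training section above:

```python
# Dynamic loss scaling requires a Gluon trainer created with update_on_kvstore=False
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1},
                        update_on_kvstore=False)
amp.init_trainer(trainer)  # attach AMP's dynamic loss scaling to this trainer

with autograd.record():
    out = net(data)
    l = loss(out, label)
    # Scale the loss so small FP16 gradients do not underflow;
    # the trainer unscales gradients before the update
    with amp.scale_loss(l, trainer) as scaled_loss:
        autograd.backward(scaled_loss)
trainer.step(batch_size)
```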
35 changes: 35 additions & 0 deletions example/automatic-mixed-precision/README.md
@@ -0,0 +1,35 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

# Conversion of FP32 models to Mixed Precision Models


This folder contains examples for converting FP32 models to mixed precision models. The script allows for converting FP32 symbolic models or Gluon models to mixed precision models.

## Basic Usage

1. AMP model conversion for a Gluon model, casting params to FP16 wherever possible. The command below converts the `resnet101_v1` model to a mixed precision model and casts params to FP16 wherever possible, then loads the converted model and runs inference on it.

```bash
python amp_model_conversion.py --model resnet101_v1 --use-gluon-model --run-dummy-inference --cast-optional-params
```

2. AMP model conversion for a symbolic model, keeping the params in FP32 wherever possible (`--cast-optional-params` not used).

```bash
python amp_model_conversion.py --model imagenet1k-resnet-152 --run-dummy-inference
```
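
The converted symbolic model is saved next to the original checkpoint as `<prefix>-amp-symbol.json` and `<prefix>-amp-<epoch>.params`. A minimal sketch of loading it back for inference (file names below assume the `imagenet1k-resnet-152` checkpoint at epoch 0, downloaded into `model/`):

```python
import mxnet as mx

# Load the converted symbol and params saved by the conversion script
sym = mx.sym.load('model/imagenet1k-resnet-152-amp-symbol.json')
save_dict = mx.nd.load('model/imagenet1k-resnet-152-amp-0000.params')
arg_params = {k[4:]: v for k, v in save_dict.items() if k.startswith('arg:')}
aux_params = {k[4:]: v for k, v in save_dict.items() if k.startswith('aux:')}

mod = mx.mod.Module(sym, data_names=['data'], label_names=['softmax_label'],
                    context=mx.gpu(0))
mod.bind(data_shapes=[('data', (1, 3, 224, 224))],
         label_shapes=[('softmax_label', (1,))])
mod.set_params(arg_params, aux_params)

# The Gluon path exports <model>-amp-symbol.json / <model>-amp-0000.params,
# which can be re-imported as a Gluon block:
net = mx.gluon.nn.SymbolBlock.imports('resnet101_v1-amp-symbol.json', ['data'],
                                      'resnet101_v1-amp-0000.params', ctx=mx.gpu(0))
```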
119 changes: 119 additions & 0 deletions example/automatic-mixed-precision/amp_model_conversion.py
@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import logging
import argparse
import mxnet as mx
from common import modelzoo
import gluoncv
from gluoncv.model_zoo import get_model
from mxnet.contrib.amp import amp
import numpy as np

def download_model(model_name, logger=None):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if logger is not None:
        logger.info('Downloading model {}... into path {}'.format(model_name, model_path))
    return modelzoo.download_model(model_name, model_path)


def save_symbol(fname, sym, logger=None):
    if logger is not None:
        logger.info('Saving symbol into file at {}'.format(fname))
    sym.save(fname, remove_amp_cast=False)


def save_params(fname, arg_params, aux_params, logger=None):
    if logger is not None:
        logger.info('Saving params into file at {}'.format(fname))
    save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
    save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
    mx.nd.save(fname, save_dict)


if __name__ == '__main__':
    symbolic_models = ['imagenet1k-resnet-152',
                       'imagenet1k-resnet-18',
                       'imagenet1k-resnet-34',
                       'imagenet1k-resnet-50',
                       'imagenet1k-resnet-101',
                       'imagenet1k-resnext-50',
                       'imagenet1k-resnext-101',
                       'imagenet1k-resnext-101-64x4d',
                       'imagenet11k-place365ch-resnet-152',
                       'imagenet11k-place365ch-resnet-50']
    gluon_models = ['resnet18_v1',
                    'resnet50_v1',
                    'resnet101_v1',
                    'squeezenet1.0',
                    'mobilenet1.0',
                    'mobilenetv2_1.0',
                    'inceptionv3']
    models = symbolic_models + gluon_models

    parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model')
    parser.add_argument('--model', type=str, choices=models)
    parser.add_argument('--run-dummy-inference', action='store_true', default=False,
                        help='Will generate random input of shape (1, 3, 224, 224) '
                             'and run a dummy inference forward pass')
    parser.add_argument('--use-gluon-model', action='store_true', default=False,
                        help='If enabled, will download pretrained model from Gluon-CV '
                             'and convert to mixed precision model')
    parser.add_argument('--cast-optional-params', action='store_true', default=False,
                        help='If enabled, will try to cast params to target dtype wherever possible')
    args = parser.parse_args()
    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    if not args.use_gluon_model:
        assert args.model in symbolic_models, "Please choose one of the available symbolic models: {} \
If you want to use a gluon model, use the script with --use-gluon-model".format(symbolic_models)

        prefix, epoch = download_model(model_name=args.model, logger=logger)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, arg_params, aux_params,
                                                                             cast_optional_params=args.cast_optional_params)
        sym_name = "%s-amp-symbol.json" % (prefix)
        save_symbol(sym_name, result_sym, logger)
        param_name = '%s-%04d.params' % (prefix + '-amp', epoch)
        save_params(param_name, result_arg_params, result_aux_params, logger)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy input, batch size: 1")
            mod = mx.mod.Module(result_sym, data_names=['data'], label_names=['softmax_label'], context=mx.gpu(0))
            mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]])
            # Use the converted params; with --cast-optional-params they may no longer be FP32
            mod.set_params(result_arg_params, result_aux_params)
            mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                        label=[mx.nd.ones((1,))]))
            result = mod.get_outputs()[0].asnumpy()
            logger.info("Inference run successfully")
    else:
        assert args.model in gluon_models, "Please choose one of the available gluon models: {} \
If you want to use a symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models)
        net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
        net.hybridize()
        result_before = net.forward(mx.nd.zeros((1, 3, 224, 224)))
        net.export("{}".format(args.model))
        net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
        net.export("{}-amp".format(args.model), remove_amp_cast=False)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
            result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
            logger.info("Inference run successfully")
1 change: 1 addition & 0 deletions example/automatic-mixed-precision/common
49 changes: 49 additions & 0 deletions include/mxnet/c_api.h
@@ -1764,6 +1764,55 @@ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_ha
const mx_uint num_offline, const char **offline_params,
const char *quantized_dtype, const bool calib_quantize);

/*!
 * \brief Convert a symbol into a mixed precision symbol with cast operators for target dtype casting
 * \param sym_handle symbol to be converted
 * \param ret_sym_handle mixed precision symbol result
 * \param num_args number of arguments for known dtypes
 * \param arg_type_data arg types of the arguments
 * \param num_ind_ptr number of entries in the index pointer over conditional param values
 * \param ind_ptr index pointer delimiting the param values belonging to each conditional op
 * \param target_dtype target_dtype for mixed precision symbol
 * \param cast_optional_params whether to cast optional params to target_dtype
 * \param num_target_dtype_op_names number of ops to be cast to target_dtype
 * \param num_fp32_op_names number of ops to be cast to FP32
 * \param num_widest_dtype_op_names number of ops to be cast to the widest dtype
 * \param num_conditional_fp32_op_names number of ops to be cast to FP32 based on a condition
 * \param num_excluded_symbols number of symbols to be excluded from casting
 * \param num_model_params number of model parameters
 * \param target_dtype_op_names op names to be cast to target_dtype
 * \param fp32_op_names op names to be cast to FP32
 * \param widest_dtype_op_names op names to be cast to the widest dtype
 * \param conditional_fp32_op_names op names to be cast to FP32 conditionally
 * \param excluded_symbols symbol names to be excluded from casting
 * \param conditional_param_names param names for conditional FP32 casting
 * \param conditional_param_vals param values for conditional FP32 casting
 * \param model_param_names names of model parameters
 * \param arg_names argument names for which type information is provided
 */
MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle,
SymbolHandle *ret_sym_handle,
mx_uint num_args,
const int* arg_type_data,
mx_uint num_ind_ptr,
const int* ind_ptr,
const int* target_dtype,
const int cast_optional_params,
const mx_uint num_target_dtype_op_names,
const mx_uint num_fp32_op_names,
const mx_uint num_widest_dtype_op_names,
const mx_uint num_conditional_fp32_op_names,
const mx_uint num_excluded_symbols,
const mx_uint num_model_params,
const char **target_dtype_op_names,
const char **fp32_op_names,
const char **widest_dtype_op_names,
const char **conditional_fp32_op_names,
const char **excluded_symbols,
const char **conditional_param_names,
const char **conditional_param_vals,
const char **model_param_names,
const char **arg_names);
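
On the Python side this entry point is wrapped by the contrib AMP frontend, which assembles the op-name lists from its registry and calls the C API. A minimal sketch of the Python-level call (the checkpoint prefix and excluded symbol name are placeholder assumptions):

```python
import mxnet as mx
from mxnet.contrib import amp

# Load an FP32 checkpoint and convert its symbol; convert_symbol drives
# MXReducePrecisionSymbol under the hood
sym, arg_params, aux_params = mx.model.load_checkpoint('model', 0)
fp16_sym = amp.convert_symbol(sym,
                              target_dtype='float16',
                              excluded_sym_names=['softmax'],
                              cast_optional_params=False)
```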
/*!
* \brief Set calibration table to node attributes in the sym
* \param sym_handle symbol whose node attributes are to be set by calibration table