diff --git a/docs/tutorials/amp/amp_tutorial.md b/docs/tutorials/amp/amp_tutorial.md index be18929e23a9..31cd23b19ae7 100644 --- a/docs/tutorials/amp/amp_tutorial.md +++ b/docs/tutorials/amp/amp_tutorial.md @@ -31,12 +31,14 @@ For demonstration purposes we will use synthetic data loader. ```python +import os import logging import warnings import time import mxnet as mx import mxnet.gluon as gluon from mxnet import autograd +from mxnet.test_utils import download_model import gluoncv as gcv from gluoncv.model_zoo import get_model @@ -249,6 +251,46 @@ for epoch in range(1): We got 60% speed increase from 3 additional lines of code! +## Inference with AMP + +To do inference with mixed precision for a trained model in FP32, you can use the conversion APIs: `amp.convert_model` for symbolic model and `amp.convert_hybrid_block` for gluon models. The conversion APIs will take the FP32 model as input and will return a mixed precision model, which can be used to run inference. Below, we demonstrate for a gluon model and a symbolic model: 1. Conversion from FP32 model to mixed precision model 2. Run inference on the mixed precision model. + +```python +with mx.Context(mx.gpu(0)): + # Below is an example of converting a gluon hybrid block to a mixed precision block + model = get_model("resnet50_v1") + model.collect_params().initialize(ctx=mx.current_context()) + model.hybridize() + model(mx.nd.zeros((1, 3, 224, 224))) + converted_model = amp.convert_hybrid_block(model) + + # Run dummy inference with the converted gluon model + result = converted_model.forward(mx.nd.random.uniform(shape=(1, 3, 224, 224), + dtype=np.float32)) + + # Below is an example of converting a symbolic model to a mixed precision model + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if not os.path.isdir(model_path): + os.mkdir(model_path) + prefix, epoch = mx.test_utils.download_model("imagenet1k-resnet-18", dst_dir=model_path) + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, + arg_params, + aux_params) + + # Run dummy inference with the converted symbolic model + mod = mx.mod.Module(result_sym, data_names=["data"], label_names=["softmax_label"], context=mx.current_context()) + mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]]) + mod.set_params(result_arg_params, result_aux_params) + mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))], + label=[mx.nd.ones((1,))])) + mod.get_outputs()[0].wait_to_read() + print("Conversion and Inference completed successfully") +``` + + + ## Current limitations of AMP - AMP's dynamic loss scaling currently supports only Gluon trainer with `update_on_kvstore=False` option set diff --git a/example/automatic-mixed-precision/README.md b/example/automatic-mixed-precision/README.md new file mode 100644 index 000000000000..49147cd87242 --- /dev/null +++ b/example/automatic-mixed-precision/README.md @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + +# Conversion of FP32 models to Mixed Precision Models + + +This folder contains examples for converting FP32 models to mixed precision models. The script allows for converting FP32 symbolic models or gluon models to mixed precision model. + +## Basic Usages + +1. AMP Model Conversion for a gluon model, casting the params wherever possible to FP16. The below script will convert the `resnet101_v1` model to Mixed Precision Model and cast params to FP16 wherever possible, load this converted model and run inference on it. + +```bash +python amp_model_conversion.py --model resnet101_v1 --use-gluon-model --run-dummy-inference --cast-optional-params +``` + +2. AMP Model Conversion for a symbolic model, keeping the params in FP32 wherever possible (--cast-optional-params not used). + +```bash +python amp_model_conversion.py --model imagenet1k-resnet-152 --run-dummy-inference +``` diff --git a/example/automatic-mixed-precision/amp_model_conversion.py b/example/automatic-mixed-precision/amp_model_conversion.py new file mode 100644 index 000000000000..fcc2ad69dd62 --- /dev/null +++ b/example/automatic-mixed-precision/amp_model_conversion.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import logging +import argparse +import mxnet as mx +from common import modelzoo +import gluoncv +from gluoncv.model_zoo import get_model +from mxnet.contrib.amp import amp +import numpy as np + +def download_model(model_name, logger=None): + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if logger is not None: + logger.info('Downloading model {}... into path {}'.format(model_name, model_path)) + return modelzoo.download_model(args.model, os.path.join(dir_path, 'model')) + + +def save_symbol(fname, sym, logger=None): + if logger is not None: + logger.info('Saving symbol into file at {}'.format(fname)) + sym.save(fname, remove_amp_cast=False) + + +def save_params(fname, arg_params, aux_params, logger=None): + if logger is not None: + logger.info('Saving params into file at {}'.format(fname)) + save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()} + save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()}) + mx.nd.save(fname, save_dict) + + +if __name__ == '__main__': + symbolic_models = ['imagenet1k-resnet-152', + 'imagenet1k-resnet-18', + 'imagenet1k-resnet-34', + 'imagenet1k-resnet-50', + 'imagenet1k-resnet-101', + 'imagenet1k-resnext-50', + 'imagenet1k-resnext-101', + 'imagenet1k-resnext-101-64x4d', + 'imagenet11k-place365ch-resnet-152', + 'imagenet11k-place365ch-resnet-50'] + gluon_models = ['resnet18_v1', + 'resnet50_v1', + 'resnet101_v1', + 'squeezenet1.0', + 'mobilenet1.0', + 'mobilenetv2_1.0', + 'inceptionv3'] + models = symbolic_models + gluon_models + + parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model') + parser.add_argument('--model', type=str, choices=models) + parser.add_argument('--run-dummy-inference', action='store_true', default=False, + help='Will generate random input of shape (1, 3, 224, 224) ' + 'and run a dummy inference forward pass') + parser.add_argument('--use-gluon-model', action='store_true', default=False, + help='If enabled, will download pretrained model from Gluon-CV ' + 'and convert to mixed precision model ') + parser.add_argument('--cast-optional-params', action='store_true', default=False, + help='If enabled, will try to cast params to target dtype wherever possible') + args = parser.parse_args() + logging.basicConfig() + logger = logging.getLogger('logger') + logger.setLevel(logging.INFO) + + if not args.use_gluon_model: + assert args.model in symbolic_models, "Please choose one of the available symbolic models: {} \ + If you want to use gluon use the script with --use-gluon-model".format(symbolic_models) + + prefix, epoch = download_model(model_name=args.model, logger=logger) + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, arg_params, aux_params, + cast_optional_params=args.cast_optional_params) + sym_name = "%s-amp-symbol.json" % (prefix) + save_symbol(sym_name, result_sym, logger) + param_name = '%s-%04d.params' % (prefix + '-amp', epoch) + save_params(param_name, result_arg_params, result_aux_params, logger) + if args.run_dummy_inference: + logger.info("Running inference on the mixed precision model with dummy input, batch size: 1") + mod = mx.mod.Module(result_sym, data_names=['data'], label_names=['softmax_label'], context=mx.gpu(0)) + mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]]) + mod.set_params(arg_params, aux_params) + mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))], + label=[mx.nd.ones((1,))])) + result = mod.get_outputs()[0].asnumpy() + logger.info("Inference run successfully") + else: + assert args.model in gluon_models, "Please choose one of the available gluon models: {} \ + If you want to use symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models) + net = gluoncv.model_zoo.get_model(args.model, pretrained=True) + net.hybridize() + result_before1 = net.forward(mx.nd.zeros((1, 3, 224, 224))) + net.export("{}".format(args.model)) + net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params) + net.export("{}-amp".format(args.model), remove_amp_cast=False) + if args.run_dummy_inference: + logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1") + result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0))) + result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0))) + logger.info("Inference run successfully") diff --git a/example/automatic-mixed-precision/common b/example/automatic-mixed-precision/common new file mode 120000 index 000000000000..cafb9140ab6a --- /dev/null +++ b/example/automatic-mixed-precision/common @@ -0,0 +1 @@ +../image-classification/common \ No newline at end of file diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index a2da6db978cb..bd30e44f910c 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1764,6 +1764,55 @@ MXNET_DLL int MXQuantizeSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_ha const mx_uint num_offline, const char **offline_params, const char *quantized_dtype, const bool calib_quantize); +/*! + * \brief Convert a symbol into a mixed precision symbol with cast operators for target dtype casting + * \param sym_handle symbol to be converted + * \param ret_sym_handle mixed precision symbol result + * \param num_args number of arguments for known dtypes + * \param arg_type_data arg types of the arguments + * \param target_dtype target_dtype for mixed precision symbol + * \param cast_optional_params whether to cast optional params to target_dtype + * \param num_target_dtype_op_names number of ops to be casted to target_dtype + * \param num_fp32_op_names number of ops to be casted to FP32 + * \param num_widest_dtype_op_names number of ops to be casted to widest dtype + * \param num_conditional_fp32_op_names number of ops to be casted to FP32 based on a condition + * \param num_excluded_symbols number of symbols to be excluded from casting + * \param num_model_params number of model parameters + * \param num_widest_dtype_op_names number of ops to be casted to the widest dtype + * \param num_conditional_fp32_op_names number of ops to be cast to fp32 based on precision + * \param target_dtype_op_names op names to be casted to target_dtype + * \param fp32_op_names op names to be casted to fp32 + * \param widest_dtype_op_names names to be casted to widest dtype + * \param conditional_fp32_op_names names to be casted to FP32 conditionally + * \param excluded_symbols symbol names to be excluded from casting + * \param param_names param names for conditional FP32 casting + * \param param_values param values for conditional FP32 casting + * \param arg_names argument names for which type information is provided + * \param model_param_names names for model parameters + */ +MXNET_DLL int MXReducePrecisionSymbol(SymbolHandle sym_handle, + SymbolHandle *ret_sym_handle, + mx_uint num_args, + const int* arg_type_data, + mx_uint num_ind_ptr, + const int* ind_ptr, + const int* target_dtype, + const int cast_optional_params, + const mx_uint num_target_dtype_op_names, + const mx_uint num_fp32_op_names, + const mx_uint num_widest_dtype_op_names, + const mx_uint num_conditional_fp32_op_names, + const mx_uint num_excluded_symbols, + const mx_uint num_model_params, + const char **target_dtype_op_names, + const char **fp32_op_names, + const char **widest_dtype_op_names, + const char **conditional_fp32_op_names, + const char **excluded_symbols, + const char **conditional_param_names, + const char **conditional_param_vals, + const char **model_param_names, + const char **arg_names); /*! * \brief Set calibration table to node attributes in the sym * \param sym_handle symbol whose node attributes are to be set by calibration table diff --git a/python/mxnet/contrib/amp/amp.py b/python/mxnet/contrib/amp/amp.py index bb3972092139..ef2f7209d946 100755 --- a/python/mxnet/contrib/amp/amp.py +++ b/python/mxnet/contrib/amp/amp.py @@ -17,21 +17,28 @@ # coding: utf-8 """Functions for enabling AMP (automatic mixed precision).""" -__all__ = ['init', 'init_trainer', 'scale_loss', 'unscale'] +__all__ = ['init', 'init_trainer', 'scale_loss', 'unscale', 'convert_model', + 'convert_hybrid_block', 'list_fp16_ops', 'list_fp32_ops', + 'list_fp16_fp32_ops', 'list_conditional_fp32_ops', + 'convert_symbol'] from types import MethodType +from array import array +import ctypes import logging import contextlib import numpy as np from ... import symbol +from ...context import gpu from ...symbol import Symbol from ...symbol import contrib as symbol_contrib from ... import ndarray -from ...ndarray import NDArray +from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP from . import lists from ...gluon import trainer from ... import base +from ...base import c_str_array, SymbolHandle, check_call, _LIB, mx_uint, c_array_buf from ... import optimizer as opt from .loss_scaler import LossScaler @@ -342,3 +349,345 @@ def unscale(optimizer_or_trainer): else: raise TypeError("optimizer_or_trainer should be a Gluon Trainer or " "an optimizer, instead is %s" % type(optimizer_or_trainer)) + +def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None, + fp32_ops=None, conditional_fp32_ops=None, + excluded_sym_names=None, data_names=None, + cast_optional_params=False): + """Given a symbol object representing a neural network of data type FP32 and target_dtype, + add cast layers according to the op lists (target_dtype_ops, fp32_ops, + conditional_fp32_ops) if provided, otherwise use the default + lists provided by the framework. + + Parameters + ---------- + sym : Symbol + FP32 neural network symbol + target_dtype : str or numpy, optional defaults to float16 + currently only supports float16. The target dtype indicates to add cast layers + when possible so that lower precision computation can be leveraged. + target_dtype_ops : list of strs, optional + Override the list of operator names casted to the target_dtype. + If None, uses the framework's default list to be casted to target_dtype. + fp32_ops : list of strs, optional + Override the list of operator names casted to FP32. + If None, uses the framework's default list to be casted to FP32. + conditional_fp32_ops : list of (string, string, list of string), optional + Override the list of functions to be casted to FP32. + The format of the list is + (name of the function, name of the parameter, + list of values of the parameter that make the operator to be casted to FP32) + excluded_sym_names : list of strs, optional + A list of strings that represent the names of symbols that users want to exclude + from being casted to FP16 or FP32. + data_names : list of strs, optional + A list of strings that represent input data tensor names to the model + cast_optional_params : bool, default False + Whether to cast the arg_params and aux_params that don't require to be in FP16 + because of a cast layer following it, but will reduce the computation and memory + overhead of the model if casted. + """ + assert isinstance(sym, Symbol), "First argument to convert_symbol should be Symbol" + + if target_dtype != "float16": + raise ValueError("Only target_dtype float16 is supported currently") + + if target_dtype_ops is not None: + assert isinstance(target_dtype_ops, list), "target_dtype_ops should be a list of strs" + else: + target_dtype_ops = lists.symbol.FP16_FUNCS + + if fp32_ops is not None: + assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs" + else: + fp32_ops = lists.symbol.FP32_FUNCS + + if conditional_fp32_ops is not None: + assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops should be a list" + else: + conditional_fp32_ops = lists.symbol.CONDITIONAL_FP32_FUNCS + + original_conditional_op_names = [] + conditional_op_names = [] + param_names = [] + param_vals = [] + indptr = [0] + for conditional_fp32_op in conditional_fp32_ops: + assert isinstance(conditional_fp32_op[0], str) and isinstance(conditional_fp32_op[1], str) \ + and isinstance(conditional_fp32_op[2], list), "conditional_fp32_ops should be a list of " \ + "(str, str, list of str)" + param_vals += conditional_fp32_op[2] + indptr.append(len(param_vals)) + param_names.append(conditional_fp32_op[1]) + conditional_op_names.append(conditional_fp32_op[0]) + + if excluded_sym_names is not None: + assert isinstance(excluded_sym_names, list), "excluded_sym_names should be a list of strs" + else: + excluded_sym_names = [] + + for original_conditional_fp32_op in lists.symbol.CONDITIONAL_FP32_FUNCS: + original_conditional_op_names.append(original_conditional_fp32_op[0]) + + # Op lists should not have intersection + common_ops = set(target_dtype_ops) & set(fp32_ops) + assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \ + "Common ops in target_dtype_ops and fp32_ops {}".format(common_ops) + common_ops = set(target_dtype_ops) & set(conditional_op_names) + assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \ + "Common ops in target_dtype_ops and conditional_fp32_ops {}".format(common_ops) + common_ops = set(conditional_op_names) & set(fp32_ops) + assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \ + "Common ops in fp32_ops and conditional_fp32_ops {}".format(common_ops) + + combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names) + all_fp16_fp32_ops = set(lists.symbol.FP16_FUNCS + lists.symbol.FP32_FUNCS + + lists.symbol.FP16_FP32_FUNCS + original_conditional_op_names) + + illegal_ops = combined_ops - all_fp16_fp32_ops + assert not illegal_ops, '''Can only choose ops from one of the three lists + for fp16_ops and fp32_ops + 1. amp.list_fp16_ops() + 2. amp.list_fp32_ops() + 3. amp.list_fp16_fp32_ops() + 4. amp.list_conditional_fp32_ops() + Op %s not in any of them''' % (illegal_ops) + + widest_dtype_ops = lists.symbol.WIDEST_TYPE_CASTS + target_dtype = _DTYPE_NP_TO_MX[np.dtype(target_dtype).type] + + # Prepare a data_names list based on list_inputs if its not provided + # Add all names in list for the nodes in the symbol which don't have + # __dtype__ set + attr_dict = sym.attr_dict() + if data_names is None: + data_names = [] + for sym_name in sym.list_inputs(): + if not sym_name in attr_dict: + data_names.append(sym_name) + continue + if not "__dtype__" in attr_dict[sym_name]: + data_names.append(sym_name) + model_param_names = list(set(sym.list_inputs()) - set(data_names)) + + # Since assumption is that it is a FP32 model, set dtypes for all + # data_names to float32 + str_keys = [] + sdata = [] + for k in data_names: + str_keys.append(k) + sdata.append(0) + keys = c_str_array(str_keys) + + out = SymbolHandle() + check_call(_LIB.MXReducePrecisionSymbol(sym.handle, + ctypes.byref(out), + mx_uint(len(sdata)), + c_array_buf(ctypes.c_int, array('i', sdata)), + mx_uint(len(indptr)), + c_array_buf(ctypes.c_int, array('i', indptr)), + ctypes.byref(ctypes.c_int(target_dtype)), + ctypes.c_int(cast_optional_params), + mx_uint(len(target_dtype_ops)), + mx_uint(len(fp32_ops)), + mx_uint(len(widest_dtype_ops)), + mx_uint(len(conditional_op_names)), + mx_uint(len(excluded_sym_names)), + mx_uint(len(model_param_names)), + c_str_array(target_dtype_ops), + c_str_array(fp32_ops), + c_str_array(widest_dtype_ops), + c_str_array(conditional_op_names), + c_str_array(excluded_sym_names), + c_str_array(param_names), + c_str_array(param_vals), + c_str_array(model_param_names), + keys)) + return Symbol(out) + +def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dtype_ops=None, + fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=None, + cast_optional_params=False): + """API for converting a model from FP32 model to a mixed precision model. + MXNet tries to convert the FP32 model to mixed precision model by adding + cast layers using amp_cast and amp_multicast operators which can be used for inference use cases. + The decision on which cast layer to add is based on hardcoded lists for Automatic Mixed Precision + in MXNet. These lists can be overridden by the user by providing their own lists + using : targe_precision_ops, fp32_ops, widest_precision_ops, conditional_fp32_ops + + arg_params : dict + Dictionary of name to `NDArray`. + aux_params : dict + Dictionary of name to `NDArray`. + target_dtype : str + Currently only supports float16. The target dtype indicates to add cast layers + when possible so that lower precision computation can be leveraged. + target_dtype_ops : list of strs + Override the list of operator names casted to target_dtype. + If None, uses the framework's default list to be casted to target dtype. + fp32_ops : list of strs + Override the lists of operator names casted to FP32. + If None, uses the framework's default list to be casted to FP32. + widest_dtype_ops : list of strs + A list of op names provided by user which should run in widest precision among its inputs. + If None, uses the framework's default list of widest_precision_ops. + conditional_fp32_ops : list of (string, string, list of string) + Override the list of operators to be casted to FP32. + The format of the list is + (name of the function, name of the parameter, + list of values of the parameter that make the operator to be casted to + fp32) + excluded_sym_names : list of strs + A list of strings that represent the names of symbols that users want to exclude + from being executed in lower precision. + cast_optional_params : bool, default False + Whether to cast the arg_params and aux_params that don't require to be in FP16 + because of a cast layer following it, but will reduce the computation and memory + overhead of the model if casted. + """ + if excluded_sym_names is None: + excluded_sym_names = [] + if not isinstance(excluded_sym_names, list): + raise ValueError('excluded_sym_names must be a list of strings representing' + ' the names of the symbols that should not be casted,' + ' while received type %s' % str(type(excluded_sym_names))) + + if target_dtype != "float16": + raise ValueError("Only target_dtype float16 is supported currently") + + assert isinstance(sym, Symbol), "First argument to convert_model should be Symbol" + assert isinstance(arg_params, dict), "Second argument to convert_model should be a dict of name to ndarray" + assert isinstance(aux_params, dict), "Third argument to convert_model should be a dict of name to ndarray" + + param_names = list(arg_params.keys()) + list(aux_params.keys()) + + # Only pass non params as data_names, param types can be inferred + data_names = list(set(sym.list_inputs()) - set(param_names)) + + sym = convert_symbol(sym, target_dtype, target_dtype_ops, + fp32_ops, conditional_fp32_ops, + excluded_sym_names, data_names, + cast_optional_params) + + # If dtype is set for params, cast the param to that dtype + attr_dict = sym.attr_dict() + for sym_name in sym.list_arguments(): + if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]: + if attr_dict[sym_name]["__dtype__"] != "-1": + typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])] + arg_params[sym_name] = arg_params[sym_name].astype(typ) + + for sym_name in sym.list_auxiliary_states(): + if sym_name in attr_dict and "__dtype__" in attr_dict[sym_name]: + if attr_dict[sym_name]["__dtype__"] != "-1": + typ = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])] + aux_params[sym_name] = aux_params[sym_name].astype(typ) + + # Return the converted symbol and casted params + return sym, arg_params, aux_params + +def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None, + fp32_ops=None, conditional_fp32_ops=None, + excluded_sym_names=None, ctx=gpu(0), + cast_optional_params=False): + """Given a hybrid block/symbol block representing a FP32 model and a target_dtype, + return a block with mixed precision support which can be used for inference use cases. + + Parameters + ---------- + block : HybridBlock or SymbolBlock object + FP32 HybridBlock or SymbolBlock object + target_dtype : str or numpy + currently only supports fp16. The target dtype indicates to add cast layers + when possible so that lower precision computation can be leveraged. + target_precision_ops : list of strs + Override the list of operator names casted to target_dtype. + If None, uses the framework's default list to be casted to FP32. + conditional_fp32_ops : list of (str, str, list of str) + Override the list of functions to be casted to FP32. + The format of the list is + (name of the function, name of the parameter, + list of values of the parameter that make the operator to be casted to FP32 + excluded_sym_names : list of strs + A list of strings that represent the names of symbols that users want to exclude + from being quantized + ctx : Context + Context on which model parameters should live + cast_optional_params : bool, default False + Whether to cast the arg_params and aux_params that don't require to be in FP16 + because of a cast layer following it, but will reduce the computation and memory + overhead of the model if casted. + """ + from ...gluon import HybridBlock, SymbolBlock + assert isinstance(block, HybridBlock), "block input should be a HybridBlock" + if not block._cached_graph: + raise RuntimeError( + "Please first call block.hybridize() and then run forward with " + "this block at least once before calling export.") + + # Prepare inputs to pass to the convert_symbol API + inputs, sym = block._cached_graph + input_names = [] + for inp in inputs: + input_names.append(inp.name) + converted_sym = convert_symbol(sym, target_dtype, target_dtype_ops, + fp32_ops, conditional_fp32_ops, + excluded_sym_names, data_names=input_names, + cast_optional_params=cast_optional_params) + + arg_names = set(converted_sym.list_arguments()) + aux_names = set(converted_sym.list_auxiliary_states()) + arg_dict = {} + + # If dtype for the param was set in the json, cast the + # param to this dtype + attr_dict = converted_sym.attr_dict() + for name, param in block.collect_params().items(): + if name in arg_names: + arg_dict['arg:%s'%name] = param._reduce() + if name in attr_dict and "__dtype__" in attr_dict[name]: + if attr_dict[name]["__dtype__"] != "-1": + typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])] + arg_dict['arg:%s'%name] = arg_dict['arg:%s'%name].astype(typ) + else: + assert name in aux_names + arg_dict['aux:%s'%name] = param._reduce() + if name in attr_dict and "__dtype__" in attr_dict[name]: + if attr_dict[name]["__dtype__"] != "-1": + typ = _DTYPE_MX_TO_NP[int(attr_dict[name]["__dtype__"])] + arg_dict['aux:%s'%name] = arg_dict['aux:%s'%name].astype(typ) + + # Create a symbolblock and cast the params to the dtypes based + # on the dtype information from the converted_symbol + ret = SymbolBlock(converted_sym, inputs) + for key, param in ret.collect_params().items(): + arg_param_name = "arg:%s" % key + if arg_param_name in arg_dict and param.dtype != arg_dict[arg_param_name].dtype: + param.cast(arg_dict[arg_param_name].dtype) + + aux_param_name = "aux:%s" % key + if aux_param_name in arg_dict and param.dtype != arg_dict[aux_param_name].dtype: + param.cast(arg_dict[aux_param_name].dtype) + + ret.collect_params().load_dict(arg_dict, ctx=ctx) + return ret + +def list_fp16_ops(): + """Get the default list of FP16 ops for AMP + """ + return lists.symbol.FP16_FUNCS + +def list_fp32_ops(): + """Get the default list of FP32 ops for AMP + """ + return lists.symbol.FP32_FUNCS + +def list_fp16_fp32_ops(): + """Get the default list of ops which run in both FP16 and FP32 + """ + return lists.symbol.FP16_FP32_FUNCS + +def list_conditional_fp32_ops(): + """Get the conditional fp32 ops list + """ + return lists.symbol.CONDITIONAL_FP32_FUNCS diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 548407584715..a174d82341af 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -16,7 +16,7 @@ # under the License. # coding: utf-8 -# pylint: disable=unnecessary-pass +# pylint: disable=unnecessary-pass, too-many-lines """Neural network parameter.""" __all__ = ['DeferredInitializationError', 'Parameter', 'Constant', 'ParameterDict', 'tensor_types'] @@ -955,23 +955,51 @@ def load(self, filename, ctx=None, allow_missing=False, assert name.startswith(restore_prefix), \ "restore_prefix is '%s' but Parameters name '%s' does not start " \ "with '%s'"%(restore_prefix, name, restore_prefix) - lprefix = len(restore_prefix) ndarray_load = ndarray.load(filename) + self.load_dict(ndarray_load, ctx, allow_missing, + ignore_extra, restore_prefix, filename, cast_dtype, dtype_source) + + def load_dict(self, param_dict, ctx=None, allow_missing=False, + ignore_extra=False, restore_prefix='', filename=None, cast_dtype=False, + dtype_source="current"): + """Load parameters from dict + + Parameters + ---------- + param_dict : dict + Dictionary containing model parameters, preprended with arg: and aux: names + ctx : Context or list of Context + Context(s) initialize loaded parameters on. + allow_missing : bool, default False + Whether to silently skip loading parameters not represented in the file. + ignore_extra : bool, default False + Whether to silently ignore parameters from the file that are not + present in this ParameterDict. + restore_prefix : str, default '' + prepend prefix to names of stored parameters before loading + filename : str, default None + cast_dtype : bool, default False + Cast the data type of the NDArray loaded from the checkpoint to the dtype + provided by the Parameter if any + """ + lprefix = len(restore_prefix) loaded = [(k[4:] if k.startswith('arg:') or k.startswith('aux:') else k, v) \ - for k, v in ndarray_load.items()] if isinstance(ndarray_load, dict) else ndarray_load + for k, v in param_dict.items()] if isinstance(param_dict, dict) else param_dict arg_dict = {restore_prefix+k: v for k, v in loaded} + error_str = "file: %s" % (filename) if filename else "param_dict" if not allow_missing: for name in self.keys(): assert name in arg_dict, \ - "Parameter '%s' is missing in file '%s', which contains parameters: %s. " \ + "Parameter '%s' is missing in %s, which contains parameters: %s. " \ "Please make sure source and target networks have the same prefix."%( - name[lprefix:], filename, _brief_print_list(arg_dict.keys())) + name[lprefix:], error_str, _brief_print_list(arg_dict.keys())) for name in arg_dict: if name not in self._params: assert ignore_extra, \ - "Parameter '%s' loaded from file '%s' is not present in ParameterDict, " \ + "Parameter '%s' loaded from %s is not present in ParameterDict, " \ "choices are: %s. Set ignore_extra to True to ignore. " \ "Please make sure source and target networks have the same prefix."%( - name[lprefix:], filename, _brief_print_list(self._params.keys())) + name[lprefix:], error_str, _brief_print_list(self._params.keys())) continue - self[name]._load_init(arg_dict[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source) + self[name]._load_init(arg_dict[name], ctx, cast_dtype=cast_dtype, + dtype_source=dtype_source) diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py index c4050699bd52..637acce317cc 100755 --- a/python/mxnet/module/executor_group.py +++ b/python/mxnet/module/executor_group.py @@ -26,6 +26,7 @@ from .. import ndarray as nd from ..io import DataDesc from ..executor_manager import _split_input_slice +from ..ndarray import _DTYPE_MX_TO_NP def _load_general(data, targets, major_axis): @@ -651,6 +652,13 @@ def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group): input_shapes.update(dict(label_shapes)) input_types = {x.name: x.dtype for x in data_shapes} + attr_dict = self.symbol.attr_dict() + + for sym_name in self.symbol.list_inputs(): + if sym_name in input_types and sym_name in attr_dict \ + and "__dtype__" in attr_dict[sym_name] and attr_dict[sym_name]["__dtype__"] != "-1": + input_types[sym_name] = _DTYPE_MX_TO_NP[int(attr_dict[sym_name]["__dtype__"])] + if label_shapes is not None: input_types.update({x.name: x.dtype for x in label_shapes}) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index d247c0fcde95..aa46a9628b81 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -29,6 +29,7 @@ import logging import bz2 import zipfile +import json from contextlib import contextmanager import numpy as np import numpy.testing as npt @@ -1521,6 +1522,72 @@ def download(url, fname=None, dirname=None, overwrite=False, retries=5): logging.info("downloaded %s into %s successfully", url, fname) return fname +def download_model(model_name, dst_dir='./', meta_info=None): + """Download a model from data.mxnet.io + + Parameters + ---------- + model_name : str + Model name to download + dst_dir : str + Destination Directory to download the model + meta_info : dict of dict + Mapping from model_name to dict of the following structure: + {'symbol': url, 'params': url} + + Returns + ------- + Two element tuple containing model_name and epoch for the params saved + """ + _base_model_url = 'http://data.mxnet.io/models/' + _default_model_info = { + 'imagenet1k-inception-bn': {'symbol':_base_model_url+'imagenet/inception-bn/Inception-BN-symbol.json', + 'params':_base_model_url+'imagenet/inception-bn/Inception-BN-0126.params'}, + 'imagenet1k-resnet-18': {'symbol':_base_model_url+'imagenet/resnet/18-layers/resnet-18-symbol.json', + 'params':_base_model_url+'imagenet/resnet/18-layers/resnet-18-0000.params'}, + 'imagenet1k-resnet-34': {'symbol':_base_model_url+'imagenet/resnet/34-layers/resnet-34-symbol.json', + 'params':_base_model_url+'imagenet/resnet/34-layers/resnet-34-0000.params'}, + 'imagenet1k-resnet-50': {'symbol':_base_model_url+'imagenet/resnet/50-layers/resnet-50-symbol.json', + 'params':_base_model_url+'imagenet/resnet/50-layers/resnet-50-0000.params'}, + 'imagenet1k-resnet-101': {'symbol':_base_model_url+'imagenet/resnet/101-layers/resnet-101-symbol.json', + 'params':_base_model_url+'imagenet/resnet/101-layers/resnet-101-0000.params'}, + 'imagenet1k-resnet-152': {'symbol':_base_model_url+'imagenet/resnet/152-layers/resnet-152-symbol.json', + 'params':_base_model_url+'imagenet/resnet/152-layers/resnet-152-0000.params'}, + 'imagenet1k-resnext-50': {'symbol':_base_model_url+'imagenet/resnext/50-layers/resnext-50-symbol.json', + 'params':_base_model_url+'imagenet/resnext/50-layers/resnext-50-0000.params'}, + 'imagenet1k-resnext-101': {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-symbol.json', + 'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-0000.params'}, + 'imagenet1k-resnext-101-64x4d': + {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-symbol.json', + 'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-0000.params'}, + 'imagenet11k-resnet-152': + {'symbol':_base_model_url+'imagenet-11k/resnet-152/resnet-152-symbol.json', + 'params':_base_model_url+'imagenet-11k/resnet-152/resnet-152-0000.params'}, + 'imagenet11k-place365ch-resnet-152': + {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-152-symbol.json', + 'params':_base_model_url+'imagenet-11k-place365-ch/resnet-152-0000.params'}, + 'imagenet11k-place365ch-resnet-50': + {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-50-symbol.json', + 'params':_base_model_url+'imagenet-11k-place365-ch/resnet-50-0000.params'}, + } + + + if meta_info is None: + meta_info = _default_model_info + meta_info = dict(meta_info) + if model_name not in meta_info: + return (None, 0) + if not os.path.isdir(dst_dir): + os.mkdir(dst_dir) + meta = dict(meta_info[model_name]) + assert 'symbol' in meta, "missing symbol url" + model_name = os.path.join(dst_dir, model_name) + mx.test_utils.download(meta['symbol'], model_name+'-symbol.json') + assert 'params' in meta, "mssing parameter file url" + mx.test_utils.download(meta['params'], model_name+'-0000.params') + return (model_name, 0) + + def get_mnist(): """Download and load the MNIST dataset @@ -2073,6 +2140,22 @@ def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='defa compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol) assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=rtol, atol=atol) + +def same_symbol_structure(sym1, sym2): + """Compare two symbols to check if they have the same computation graph structure. + Returns true if operator corresponding to a particular node id is same in both + symbols for all nodes + """ + conf = json.loads(sym1.tojson()) + nodes = conf["nodes"] + conf2 = json.loads(sym2.tojson()) + nodes2 = conf2["nodes"] + for node1, node2 in zip(nodes, nodes2): + if node1["op"] != node2["op"]: + return False + return True + + class EnvManager(object): """Environment variable setter and unsetter via with idiom""" def __init__(self, key, val): diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 4c6229ee29b0..80ae5438c20d 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -810,6 +810,210 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, API_END_HANDLE_ERROR(delete s); } +// helper function to add mapping of node_name -> dtype map +// for the given indexed graph and inferred_dtypes +static void _SetInputDTypes( + const nnvm::IndexedGraph& idx, + const nnvm::DTypeVector& inferred_dtypes, + std::unordered_map* node_name_dtype_map, + std::unordered_map* node_without_dtype_map) { + const std::string dtype_keyword = "__dtype__"; + for (uint32_t nid : idx.input_nodes()) { + const auto& node = idx[nid].source; + const auto& node_with_dtype = node->attrs.dict.find(dtype_keyword); + // input nodes classified into nodes_with_dtype, nodes_without_dtype + // This classification required because if param_names not provided + // we want to update dtypes of only those nodes which have dtypes set + // inferred_dtypes are obtained for the nodes, if unknown + // dtype is set to fp32 + if (node_with_dtype != node->attrs.dict.end()) { + if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) { + (*node_name_dtype_map)[node->attrs.name] = 0; + } else { + (*node_name_dtype_map)[node->attrs.name] = + inferred_dtypes[idx.entry_id(nid, 0)]; + } + } else { + if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) { + (*node_without_dtype_map)[node->attrs.name] = 0; + } else { + (*node_without_dtype_map)[node->attrs.name] = + inferred_dtypes[idx.entry_id(nid, 0)]; + } + } + } +} + +// helper function update the node dtype attrs for a vector of nodeptrs +// given the node name to dtype information and the names of model_params +// if model_params is provided the function will dtype of only model params. +// if model_params is empty, the function will dtype of all nodes which had +// a prior dtype set. +// args is a const_reference vector of NodePtrs. NodePtrs are immutable but +// the Nodes they are pointing will be mutated in this function +static void _UpdateSymDTypeAttrs( + const std::unordered_map& node_name_dtype_map, + const std::unordered_map& node_without_dtype_map, + const std::unordered_set& model_params, + const std::vector& args) { + const std::string dtype_keyword = "__dtype__"; + + // Update args to have the right dtype attrs + if (model_params.size() > 0) { + // if model params provided, set dtype only for model params + for (size_t i = 0; i < args.size(); ++i) { + const std::string& node_name = args[i]->attrs.name; + auto it_model_params = model_params.find(node_name); + auto it_with_dtype = node_name_dtype_map.find(node_name); + auto it_without_dtype = node_without_dtype_map.find(node_name); + if (it_model_params != model_params.end()) { + // need to update __dtype__ attribute if already set, else set it + if (it_with_dtype != node_name_dtype_map.end()) { + args[i]->attrs.dict[dtype_keyword] = + std::to_string(it_with_dtype->second); + } else { + CHECK(it_without_dtype != node_without_dtype_map.end()) + << "make sure all nodes without dtype have properly been added " + "in node_without_dtype_map"; + args[i]->attrs.dict[dtype_keyword] = + std::to_string(it_without_dtype->second); + } + } + } + } else { + // if model params not provided, update __dtype__ for all inputs, + // which already had it set, don't touch the rest + for (size_t i = 0; i < args.size(); ++i) { + auto it = node_name_dtype_map.find(args[i]->attrs.name); + if (it != node_name_dtype_map.end()) { + if (args[i]->attrs.dict.find(dtype_keyword) != + args[i]->attrs.dict.end()) { + args[i]->attrs.dict[dtype_keyword] = std::to_string(it->second); + } + } + } + } +} + +int MXReducePrecisionSymbol(SymbolHandle sym_handle, + SymbolHandle *ret_sym_handle, + mx_uint num_args, + const int *arg_type_data, + mx_uint num_ind_ptr, + const int* ind_ptr, + const int* target_dtype, + const int cast_optional_params, + const mx_uint num_target_dtype_op_names, + const mx_uint num_fp32_op_names, + const mx_uint num_widest_dtype_op_names, + const mx_uint num_conditional_fp32_op_names, + const mx_uint num_excluded_symbols, + const mx_uint num_model_params, + const char **target_dtype_op_names, + const char **fp32_op_names, + const char **widest_dtype_op_names, + const char **conditional_fp32_op_names, + const char **excluded_symbols, + const char **param_names, + const char **param_vals, + const char **model_param_names, + const char **arg_names) { + nnvm::Symbol *result_sym = new nnvm::Symbol(); + API_BEGIN(); + nnvm::Symbol *sym = static_cast(sym_handle); + nnvm::Graph g = Symbol2Graph(*sym); + std::unordered_set target_dtype_ops; + std::unordered_set fp32_ops; + std::unordered_set widest_dtype_ops; + std::unordered_set excluded_syms; + std::unordered_set model_params; + + // conditional_fp32_ops contains the mapping of op_name -> (map of param_name -> param_values) + // which need to be conditionally selected to be casted to FP32 + std::unordered_map>> conditional_fp32_ops; + int target_dt = *target_dtype; + + for (size_t i = 0; i < num_target_dtype_op_names; ++i) { + target_dtype_ops.emplace(target_dtype_op_names[i]); + } + for (size_t i = 0; i < num_fp32_op_names; ++i) { + fp32_ops.emplace(fp32_op_names[i]); + } + for (size_t i = 0; i < num_widest_dtype_op_names; ++i) { + widest_dtype_ops.emplace(widest_dtype_op_names[i]); + } + for (size_t i = 0; i < num_excluded_symbols; ++i) { + excluded_syms.emplace(excluded_symbols[i]); + } + for (size_t i = 0; i < num_model_params; ++i) { + model_params.emplace(model_param_names[i]); + } + + for (size_t i = 0; i < num_ind_ptr - 1; ++i) { + for (int j = ind_ptr[i]; j < ind_ptr[i + 1]; ++j) { + conditional_fp32_ops[conditional_fp32_op_names[i]][param_names[i]] + .emplace_back(std::string(param_vals[j])); + } + } + + std::unordered_map kwargs; + std::unordered_map node_name_dtype_map, node_without_dtype_map; + nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1); + for (mx_uint i = 0; i < num_args; ++i) { + kwargs[arg_names[i]] = arg_type_data[i]; + node_name_dtype_map[arg_names[i]] = arg_type_data[i]; + } + mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType"); + + g.attrs["target_dtype_ops"] = + std::make_shared(std::move(target_dtype_ops)); + g.attrs["fp32_ops"] = std::make_shared(std::move(fp32_ops)); + g.attrs["widest_dtype_ops"] = + std::make_shared(std::move(widest_dtype_ops)); + g.attrs["conditional_fp32_ops"] = + std::make_shared(std::move(conditional_fp32_ops)); + g.attrs["excluded_syms"] = + std::make_shared(std::move(excluded_syms)); + g.attrs["target_dtype"] = std::make_shared(target_dt); + + g = ApplyPass(std::move(g), "ReducePrecision"); + // Need to run type inference since it is possible that inferred + // type of some inputs has changed + g = mxnet::exec::InferType(std::move(g), std::move(arg_types), ""); + const nnvm::DTypeVector &inferred_dtypes = + g.GetAttr("dtype"); + + g.attrs["inferred_dtypes"] = std::make_shared(std::move(inferred_dtypes)); + g.attrs["target_dtype"] = std::make_shared(target_dt); + + if (cast_optional_params) { + g = ApplyPass(std::move(g), "AMPInferUnknown"); + const nnvm::DTypeVector &inferred_dtype_result = + g.GetAttr("inferred_dtype_result"); + const nnvm::IndexedGraph &idx = g.indexed_graph(); + // set node name -> input dtype mapping using infer dtype + _SetInputDTypes(idx, inferred_dtype_result, &node_name_dtype_map, &node_without_dtype_map); + } else { + const nnvm::IndexedGraph &idx = g.indexed_graph(); + // set node name -> input dtype mapping using infer dtype + _SetInputDTypes(idx, inferred_dtypes, &node_name_dtype_map, &node_without_dtype_map); + } + + + result_sym->outputs = g.outputs; + *ret_sym_handle = result_sym; + nnvm::Symbol *ret_sym = static_cast(*ret_sym_handle); + const std::vector& args = ret_sym->ListInputs(nnvm::Symbol::kAll); + + // update symbol dtype attrs using the node name -> dtype mapping, if dtype is already set + // in the symbol, else set dtype for the model_params + _UpdateSymDTypeAttrs(node_name_dtype_map, node_without_dtype_map, model_params, args); + + API_END_HANDLE_ERROR(delete result_sym); +} + int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, const mx_uint num_layers, const char** layer_names, diff --git a/src/nnvm/amp_infer_unknown.cc b/src/nnvm/amp_infer_unknown.cc new file mode 100644 index 000000000000..1de3104d054f --- /dev/null +++ b/src/nnvm/amp_infer_unknown.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2016 by Contributors + * \file low_precision_pass.cc + * \brief Use the Mixed Precision Model to infer the dtypes of + * unknown input nodes + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "../common/utils.h" +#include "../operator/tensor/amp_cast.h" + +namespace mxnet { +using nnvm::Graph; +using nnvm::NodePtr; +using nnvm::NodeEntry; +using dmlc::any; +using mxnet::op::AMPCastParam; + +// If a var node is not visited, visit it and set inferred_dtype_result as result_dtype, +// If already visited compare the result_dtype with existing inferred_dtype_result +static void CheckAndUpdateInferredDtypes( + const nnvm::DTypeVector &inferred_dtypes, const nnvm::IndexedGraph &idx, + const NodeEntry &node_entry, + mshadow::TypeFlag result_dtype, + std::unordered_map *visited_vars, + nnvm::DTypeVector *inferred_dtype_result) { + const NodePtr &input_node = node_entry.node; + if (!visited_vars->count(input_node->attrs.name)) { + if ((*inferred_dtype_result)[idx.entry_id(node_entry)] == -1) { + (*visited_vars)[input_node->attrs.name] = result_dtype; + (*inferred_dtype_result)[idx.entry_id(node_entry)] = result_dtype; + } + } else { + auto it = visited_vars->find(input_node->attrs.name); + CHECK(it != visited_vars->end()); + if (it->second != result_dtype) { + (*inferred_dtype_result)[idx.entry_id(node_entry)] = + inferred_dtypes[idx.entry_id(node_entry)]; + } + } +} + +// Graph pass to infer unknown nodes which are input nodes +// as FP16 if possible +Graph AMPInferUnknown(Graph &&src) { + const nnvm::DTypeVector &inferred_dtypes = + src.GetAttr("inferred_dtypes"); + const int target_dtype = src.GetAttr("target_dtype"); + CHECK(target_dtype == mshadow::kFloat16) + << "Only float16 target_dtype is supported yet"; + + nnvm::DTypeVector inferred_dtype_result(inferred_dtypes); + const nnvm::IndexedGraph &idx = src.indexed_graph(); + + std::unordered_map visited_vars; + + // Visits all nodes which are amp_cast and amp_multicast, + // and check if inputs to these nodes are variables. + // If input nodes are variables, set dtype for these inputs + // and check for conflicts if an input node goes to two cast nodes + DFSVisit(src.outputs, [&](const NodePtr &node) { + if (!node->is_variable()) { + std::string op_name = node->op()->name; + + if (op_name == "amp_cast") { + // for amp_cast set inferred_dtypes for input_nodes and add + // to visited_vars, if a var is being visited second time + // and already has dtype set, make sure the dtype inferred again + // is same, otherwise reset dtype to original dtype + for (const NodeEntry &node_entry : node->inputs) { + const NodePtr &input_node = node_entry.node; + if (input_node->is_variable() && + (node->attrs.dict.find("dtype") != node->attrs.dict.end())) { + const AMPCastParam ¶m = + nnvm::get(node->attrs.parsed); + CHECK(param.dtype != -1) + << "amp_cast node shouldn't have unknown dtype"; + CheckAndUpdateInferredDtypes(inferred_dtypes, idx, node_entry, + static_cast(param.dtype), + &visited_vars, &inferred_dtype_result); + } + } + } else if (op_name == "amp_multicast") { + // for amp_multicast, for non var input nodes, keep track of biggest dtype. + // If the biggest dtype is same as target_dtype, set this for the input_var nodes + // if it is not already set + mshadow::TypeFlag max_dtype = static_cast(target_dtype); + for (const NodeEntry& node_entry : node->inputs) { + const NodePtr& input_node = node_entry.node; + if (!input_node->is_variable()) { + // if one input is not a variable then don't infer the dtype of other + // input node dtypes + max_dtype = mshadow::kFloat32; + } + } + if (max_dtype == target_dtype) { + for (const NodeEntry &node_entry : node->inputs) { + const NodePtr &input_node = node_entry.node; + if (input_node->is_variable()) { + CheckAndUpdateInferredDtypes(inferred_dtypes, idx, node_entry, + max_dtype, &visited_vars, + &inferred_dtype_result); + } + } + } + } + } + }); + + Graph ret; + ret.attrs["inferred_dtype_result"] = + std::make_shared(std::move(inferred_dtype_result)); + ret.outputs = std::move(src.outputs); + return ret; +} + +NNVM_REGISTER_PASS(AMPInferUnknown) + .describe("Infer dtypes of different nodes for the mixed precision model") + .set_body(AMPInferUnknown) + .set_change_graph(true) + .provide_graph_attr("inferred_dtypes"); +} // namespace mxnet diff --git a/src/nnvm/low_precision_pass.cc b/src/nnvm/low_precision_pass.cc new file mode 100644 index 000000000000..7cd0178108f4 --- /dev/null +++ b/src/nnvm/low_precision_pass.cc @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2016 by Contributors + * \file low_precision_pass.cc + * \brief Return new graph with amp_cast and amp_multicast operators added wherever required + */ + +#include +#include +#include +#include +#include +#include +#include + +namespace mxnet { +using nnvm::Symbol; +using nnvm::Node; +using nnvm::NodePtr; +using nnvm::NodeEntry; +using nnvm::Graph; + +// create a node for operator : op_name with name : node_name +static NodePtr CreateNode(std::string op_name, std::string node_name) { + NodePtr node = Node::Create(); + node->attrs.name = node_name; + if (op_name == "nullptr") { + node->attrs.op = nullptr; + // ugly workaround because VariableParam is not exposed + node->attrs.parsed = nnvm::Symbol::CreateVariable(node->attrs.name) + .outputs[0] + .node->attrs.parsed; + } else { + node->attrs.op = Op::Get(op_name); + } + return node; +} + +static NodePtr InsertNode(std::string op_name, std::string node_name, NodePtr current, + NodeEntry previous) { + NodePtr node = CreateNode(op_name, node_name); + node->inputs.emplace_back(previous); + current->inputs.emplace_back(NodeEntry{node, 0, 0}); + return node; +} + +// get suffix for a node entry so that it can be used for amp_cast/amp_multicast node name +static std::string GetSuffix(const nnvm::NodeEntry &node_entry, + const std::unordered_map &mirror_map) { + static const auto &flist_outputs = + nnvm::Op::GetAttr("FListOutputNames"); + std::string suffix = ""; + NodePtr mirror_node = mirror_map.at(node_entry.node.get()); + if (mirror_node->op() != nullptr) { + auto list_output_names_func = flist_outputs.get(node_entry.node->op(), nullptr); + if (list_output_names_func != nullptr) { + std::vector names = list_output_names_func(node_entry.node->attrs); + suffix = "_" + names[node_entry.index]; + } else { + suffix = "_" + std::to_string(node_entry.index); + } + } + return suffix; +} + +// add amp_cast node between curr_node and input +static void AddCastNode(const nnvm::NodeEntry &e, const std::string &suffix, + const nnvm::NodeEntry &input, const std::string dtype, + nnvm::NodeEntryMap *mirror_entry_map, + NodePtr curr_node) { + NodePtr cast_node = + InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + dtype, + curr_node, input); + cast_node->attrs.dict["dtype"] = dtype; + cast_node->op()->attr_parser(&(cast_node->attrs)); + (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version}; + return; +} + +// add amp_multicast node between curr_node and inputs +static void AddMultiCastNode(const std::vector &inputs, + const std::string &node_name, + const std::unordered_map &mirror_map, + NodePtr curr_node) { + NodePtr node = + CreateNode("amp_multicast", + inputs[0].node->attrs.name + node_name + "_amp_multicast"); + for (const auto &node_entry : inputs) { + NodePtr mirror_node = mirror_map.at(node_entry.node.get()); + NodeEntry mirror_entry = NodeEntry{std::move(mirror_node), node_entry.index, + node_entry.version}; + node->inputs.emplace_back(mirror_entry); + } + node->attrs.dict["num_outputs"] = std::to_string(inputs.size()); + node->op()->attr_parser(&(node->attrs)); + for (uint32_t i = 0; i < inputs.size(); ++i) { + const auto &e = inputs[i]; + curr_node->inputs.emplace_back( + NodeEntry{node, static_cast(i), e.version}); + } + return; +} + +static bool CheckConditionalFP32( + const std::unordered_map< + std::string, std::unordered_map>> + &conditional_fp32_ops, + const std::unordered_set &excluded_syms, NodePtr node) { + if (node->is_variable() || (excluded_syms.count(node->attrs.name) > 0) || + conditional_fp32_ops.count(node->op()->name) == 0) { + return false; + } else { + // Iterate through all conditional ops + auto it = conditional_fp32_ops.find(node->op()->name); + if (it != conditional_fp32_ops.end()) { + auto it_params = it->second; + // For each param name, iterate through param values to check + // if the provided param name is equal to any of the values + for (auto it_param = it_params.begin(); it_param != it_params.end(); + it_param++) { + auto param_key = node->attrs.dict.find(it_param->first); + if (param_key != node->attrs.dict.end()) { + auto it_param_vals = it_param->second; + if (std::find(it_param_vals.begin(), it_param_vals.end(), + param_key->second) != it_param_vals.end()) { + return true; + } + } + } + } + return false; + } +} + +Graph ReducePrecision(Graph &&src) { + const auto target_dtype_ops = + src.GetAttr>("target_dtype_ops"); + const auto fp32_ops = + src.GetAttr>("fp32_ops"); + const auto widest_dtype_ops = + src.GetAttr>("widest_dtype_ops"); + const auto target_dtype = src.GetAttr("target_dtype"); + const auto excluded_syms = src.GetAttr>("excluded_syms"); + const auto conditional_fp32_ops = src.GetAttr>>>( + "conditional_fp32_ops"); + + CHECK(target_dtype == mshadow::kFloat16) + << "Only float16 target_dtype is supported yet"; + + // Additional data structures to share common cast node inputs among different nodes + std::unordered_map mirror_map; + nnvm::NodeEntryMap mirror_fp32_map; + nnvm::NodeEntryMap mirror_target_dtype_map; + + // Visit nodes in a topologically sorted order + DFSVisit(src.outputs, [&](const NodePtr &node) { + NodePtr new_node = Node::Create(*node); + new_node->inputs.clear(); + + /* 1. for node which needs to run in FP32 mode, add amp_cast operators + * (to fp32) after its inputs + * 2. for node which needs to run in FP16 mode, add amp_cast operators + * (to target_dtype) after its inputs + * 3. for nodes which need to run in widest dtype among its inputs, add + * amp_multicast operators between op and its inputs + * 4. for nodes which need to run in FP32 mode, based on a specific condition, + * check the condition, and if true add amp_cast (to fp32) after its inputs + * 4. for other nodes, create copy node and add it to the mirror_map + */ + if (!node->is_variable() && fp32_ops.count(node->op()->name) > 0 && + excluded_syms.count(node->attrs.name) == 0) { + for (const auto& node_entry : node->inputs) { + if (mirror_fp32_map.count(node_entry)) { + new_node->inputs.emplace_back(mirror_fp32_map[node_entry]); + } else { + NodePtr mirror_node = mirror_map.at(node_entry.node.get()); + NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, node_entry.version}; + std::string suffix = GetSuffix(node_entry, mirror_map); + AddCastNode(node_entry, suffix, mirror_entry, "float32", &mirror_fp32_map, + new_node); + } + } + } else if (!node->is_variable() && + target_dtype_ops.count(node->op()->name) > 0 && + excluded_syms.count(node->attrs.name) == 0) { + for (const auto& node_entry : node->inputs) { + if (mirror_target_dtype_map.count(node_entry)) { + new_node->inputs.emplace_back(mirror_target_dtype_map[node_entry]); + } else { + NodePtr mirror_node = mirror_map.at(node_entry.node.get()); + NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, node_entry.version}; + std::string suffix = GetSuffix(node_entry, mirror_map); + AddCastNode(node_entry, suffix, mirror_entry, "float16", + &mirror_target_dtype_map, new_node); + } + } + } else if (!node->is_variable() && + widest_dtype_ops.count(node->op()->name) > 0 && + excluded_syms.count(node->attrs.name) == 0) { + CHECK(node->inputs.size() > 0) + << "Please check the symbol. node name: " << node->attrs.name + << "op name " << node->op()->name << " has no inputs." + << "It is likely that something went wrong during symbolic construction."; + const auto &e = node->inputs[0]; + std::string suffix = GetSuffix(e, mirror_map); + AddMultiCastNode(node->inputs, suffix, mirror_map, new_node); + } else if (CheckConditionalFP32(conditional_fp32_ops, excluded_syms, node)) { + for (const auto& node_entry : node->inputs) { + if (mirror_fp32_map.count(node_entry)) { + new_node->inputs.emplace_back(mirror_fp32_map[node_entry]); + } else { + NodePtr mirror_node = mirror_map.at(node_entry.node.get()); + NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, node_entry.version}; + std::string suffix = GetSuffix(node_entry, mirror_map); + AddCastNode(node_entry, suffix, mirror_entry, "float32", &mirror_fp32_map, + new_node); + } + } + } else { + for (const auto& node_entry : node->inputs) { + NodePtr mirror_node = mirror_map.at(node_entry.node.get()); + new_node->inputs.emplace_back(mirror_node, node_entry.index, node_entry.version); + } + } + mirror_map[node.get()] = std::move(new_node); + }); + + std::vector outputs; + for (const auto& e : src.outputs) { + outputs.emplace_back(mirror_map.at(e.node.get()), e.index, e.version); + } + + Graph ret; + ret.outputs = std::move(outputs); + return ret; +} + +NNVM_REGISTER_PASS(ReducePrecision) + .describe("add cast layers for low precision inference") + .set_body(ReducePrecision) + .set_change_graph(true); +} // namespace mxnet diff --git a/tests/python/gpu/test_contrib_amp.py b/tests/python/gpu/test_contrib_amp.py new file mode 100644 index 000000000000..7927cc99160b --- /dev/null +++ b/tests/python/gpu/test_contrib_amp.py @@ -0,0 +1,428 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import sys +import mxnet as mx +import numpy as np +import warnings +import collections +import ctypes +import mxnet.contrib.amp as amp +from nose.tools import assert_raises +from mxnet.test_utils import set_default_context, download_model, same_symbol_structure +from mxnet.gluon.model_zoo.vision import get_model +from mxnet.gluon import SymbolBlock +from mxnet.contrib.amp import amp +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, '../unittest')) +from common import with_seed, teardown + +def test_amp_coverage(): + conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS] + + # Check for duplicates + for a in [amp.lists.symbol.FP16_FUNCS, + amp.lists.symbol.FP16_FP32_FUNCS, + amp.lists.symbol.FP32_FUNCS, + amp.lists.symbol.WIDEST_TYPE_CASTS, + conditional]: + ret = [item for item, count in collections.Counter(a).items() if count > 1] + assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists." + + t = [] + for a in [amp.lists.symbol.FP16_FUNCS, + amp.lists.symbol.FP16_FP32_FUNCS, + amp.lists.symbol.FP32_FUNCS, + amp.lists.symbol.WIDEST_TYPE_CASTS, + conditional]: + t += a + ret = [item for item, count in collections.Counter(t).items() if count > 1] + assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list." + + # Check the coverage + py_str = lambda x: x.decode('utf-8') + + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + mx.base._LIB.MXListAllOpNames(ctypes.byref(size), + ctypes.byref(plist)) + op_names = [] + for i in range(size.value): + s = py_str(plist[i]) + if not s.startswith("_backward") \ + and not s.startswith("_contrib_backward_"): + op_names.append(s) + + ret1 = set(op_names) - set(t) + + if ret1 != set(): + warnings.warn("Operators " + str(ret1) + " do not exist in AMP lists (in " + "python/mxnet/contrib/amp/lists/symbol.py) - please add them. " + """Please follow these guidelines for choosing a proper list: + - if your operator is not to be used in a computational graph + (e.g. image manipulation operators, optimizers) or does not have + inputs, put it in FP16_FP32_FUNCS list, + - if your operator requires FP32 inputs or is not safe to use with lower + precision, put it in FP32_FUNCS list, + - if your operator supports both FP32 and lower precision, has + multiple inputs and expects all inputs to be of the same + type, put it in WIDEST_TYPE_CASTS list, + - if your operator supports both FP32 and lower precision and has + either a single input or supports inputs of different type, + put it in FP16_FP32_FUNCS list, + - if your operator is both safe to use in lower precision and + it is highly beneficial to use it in lower precision, then + put it in FP16_FUNCS (this is unlikely for new operators) + - If you are not sure which list to choose, FP32_FUNCS is the + safest option""") + +@with_seed() +def test_amp_conversion(): + def check_amp_convert_symbol(): + x = mx.sym.var("x") + y = mx.sym.var("y") + z = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True) + siny = mx.sym.sin(y) + res = z + siny + # Compare symbols with similar computation graphs created using convert_symbol and manually. + res_converted = amp.convert_symbol(res, target_dtype="float16", + target_dtype_ops=["FullyConnected"], + fp32_ops=["sin"]) + + x_fp16 = mx.sym.amp_cast(x, dtype="float16") + y_fp16 = mx.sym.amp_cast(y, dtype="float16") + amp_casted_siny = mx.sym.sin(mx.sym.amp_cast(y, dtype="float32")) + z = mx.sym.FullyConnected(x_fp16, y_fp16, num_hidden=10, no_bias=True) + outs = mx.sym.amp_multicast(z, amp_casted_siny, num_outputs=2) + res_expected = outs[0] + outs[1] + assert same_symbol_structure(res_converted, res_expected), \ + "convert_symbol generating wrong computation graph" + + # convert_symbol called with incorrect inputs + assert_raises(AssertionError, amp.convert_symbol, res, + target_dtype="float16", target_dtype_ops=["FullyConnected"], + fp32_ops=["elemwise_add"]) + assert_raises(AssertionError, amp.convert_symbol, res, + target_dtype="float16", target_dtype_ops=["FullyConnected"], + fp32_ops=["Activation"], + conditional_fp32_ops=[('Activation', 'act_type', ['selu'])]) + assert_raises(AssertionError, amp.convert_symbol, res, + target_dtype="float16", target_dtype_ops=["Activation"], + fp32_ops=["Activation"], + conditional_fp32_ops=[('Activation', 'act_type', ['selu'])]) + assert_raises(AssertionError, amp.convert_symbol, res, + target_dtype="float16", target_dtype_ops=["FullyConnected"], + fp32_ops=["FullyConnected"]) + + # Test for op in conditional ops with condition not satisfied + x = mx.sym.var("x") + y = mx.sym.var("y") + fc_cond = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True) + res_converted = amp.convert_symbol(fc_cond, target_dtype="float16", + target_dtype_ops=[], + fp32_ops=["sin"], + conditional_fp32_ops=[("FullyConnected", "no_bias", ["False"])]) + + res_expected = mx.sym.FullyConnected(x, y, num_hidden=10, no_bias=True) + assert same_symbol_structure(res_converted, res_expected), \ + "convert_symbol generating wrong computation graph when conditional ops is used" + + # Test for op in conditional ops with condition satisfied + res_converted = amp.convert_symbol(fc_cond, target_dtype="float16", target_dtype_ops=[], + fp32_ops=["sin"], + conditional_fp32_ops=[("FullyConnected", "no_bias", ["True"])]) + x_fp32 = mx.sym.amp_cast(x, dtype="float32") + y_fp32 = mx.sym.amp_cast(y, dtype="float32") + res_expected = mx.sym.FullyConnected(x_fp32, y_fp32, num_hidden=10, no_bias=True) + assert same_symbol_structure(res_converted, res_expected), \ + "convert_symbol generating wrong computation graph when conditional ops used with satisfying condition" + + # Test with a real world model, default inputs for convert_symbol + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if not os.path.isdir(model_path): + os.mkdir(model_path) + + prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path) + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + inputs = {} + inputs['data'] = mx.nd.ones((1, 3, 224, 224)) + inputs.update(arg_params) + converted_sym = amp.convert_symbol(sym) + exe = converted_sym.simple_bind(mx.gpu(0), data=(1, 3, 224, 224), grad_req='null') + exe.forward(is_train=False, **inputs) + exe.outputs[0].asnumpy() + + inputs2 = {} + inputs2['data'] = mx.nd.ones((1, 3, 224, 224)) + inputs2['fc1_weight'] = inputs['fc1_weight'].astype(np.float16) + inputs2['fc1_bias'] = inputs['fc1_bias'].astype(np.float16) + + # Test with a real world model, tweak inputs for convert_symbol + converted_sym = amp.convert_symbol(sym, target_dtype="float16", + target_dtype_ops=["Convolution"], data_names=["data"], + cast_optional_params=True) + converted_sym2 = amp.convert_symbol(sym, target_dtype="float16", + target_dtype_ops=["Convolution"], data_names=["data"], + cast_optional_params=False) + + exe = converted_sym.simple_bind(mx.gpu(0), data=(1, 3, 224, 224), grad_req='null') + exe2 = converted_sym2.simple_bind(mx.gpu(), data=(1, 3, 224, 224), grad_req='null') + + converted_args = converted_sym.list_arguments() + converted_auxs = converted_sym.list_auxiliary_states() + for i, key in enumerate(exe.arg_arrays): + if converted_args[i] in arg_params: + arg_params[converted_args[i]] = arg_params[converted_args[i]].astype(exe.arg_arrays[i].dtype) + for i, key in enumerate(exe.aux_arrays): + if converted_auxs[i] in aux_params: + aux_params[converted_auxs[i]] = aux_params[converted_auxs[i]].astype(exe.aux_arrays[i].dtype) + + inputs2.update(arg_params) + exe.forward(is_train=False, **inputs2) + exe.outputs[0].wait_to_read() + + inputs['fc1_weight'] = inputs['fc1_weight'].astype(np.float16) + inputs['fc1_bias'] = inputs['fc1_bias'].astype(np.float16) + exe2.forward(is_train=False, **inputs) + exe2.outputs[0].wait_to_read() + + + def check_amp_convert_model(): + # Test with real world model, default inputs for convert_model + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if not os.path.isdir(model_path): + os.mkdir(model_path) + prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path) + + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + + # Test with real world model, tweak inputs for convert_model + result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, + arg_params, + aux_params, + target_dtype="float16", + target_dtype_ops=["Convolution"]) + mod = mx.mod.Module(result_sym, data_names=["data"], label_names=["softmax_label"], context=mx.gpu()) + mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]]) + + mod.set_params(result_arg_params, result_aux_params) + mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))], + label=[mx.nd.ones((1,))])) + mod.get_outputs()[0].asnumpy() + assert mod._arg_params["stage2_unit1_conv2_weight"].dtype == np.float32 + + # Call convert_model with cast_optional_params set to True + result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, + arg_params, + aux_params, + target_dtype="float16", + target_dtype_ops=["Convolution"], cast_optional_params=True) + mod = mx.mod.Module(result_sym, data_names=["data"], label_names=["softmax_label"], context=mx.gpu()) + mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]]) + mod.set_params(result_arg_params, result_aux_params) + mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))], + label=[mx.nd.ones((1,))])) + mod.get_outputs()[0].asnumpy() + assert mod._arg_params["stage2_unit1_conv2_weight"].dtype == np.float16 + + + def check_amp_convert_hybrid_block(): + # Test conversion for hybrid block on CPU + model_cpu = get_model("resnet50_v1") + model_cpu.collect_params().initialize(ctx=mx.cpu()) + model_cpu.hybridize() + model_cpu(mx.nd.random.uniform(0, 1, shape=(1, 3, 224, 224), ctx=mx.cpu())) + converted_model_cpu = amp.convert_hybrid_block(model_cpu) + + # Test with real world model, default inputs for convert_hybrid_block + model = get_model("resnet50_v1") + model.collect_params().initialize(ctx=mx.gpu()) + model.hybridize() + model(mx.nd.zeros((1, 3, 224, 224))) + converted_model = amp.convert_hybrid_block(model) + result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224), + dtype=np.float32)) + result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224), + dtype=np.float32)) + + # Test with real world model, tweak inputs for convert_hybrid_block + converted_model = amp.convert_hybrid_block(model, target_dtype="float16", + target_dtype_ops=["Convolution"]) + result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224), + dtype=np.float32)) + result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224), + dtype=np.float32)) + + # Check symbolic block + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if not os.path.isdir(model_path): + os.mkdir(model_path) + prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path) + net = SymbolBlock.imports(os.path.join(model_path, "imagenet1k-resnet-18-symbol.json"), + input_names=["data", "softmax_label"], + param_file=os.path.join(model_path, "imagenet1k-resnet-18-0000.params")) + net.collect_params().reset_ctx(ctx=mx.gpu()) + net.hybridize() + net(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1,))) + converted_model = amp.convert_hybrid_block(net) + result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1,))) + result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1,))) + + # Check symbolic block, tweaked inputs + converted_model = amp.convert_hybrid_block(net, target_dtype="float16", target_dtype_ops=["Convolution"]) + result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1, ))) + result = converted_model.forward(mx.nd.zeros((1, 3, 224, 224)), mx.nd.zeros((1, ))) + params = converted_model.collect_params() + assert params["stage2_unit1_conv2_weight"].dtype == np.float32 + + # Pass cast_optional_params as True to convert_hybrid_block + converted_model = amp.convert_hybrid_block(net, target_dtype="float16", target_dtype_ops=["Convolution"], + cast_optional_params=True) + params = converted_model.collect_params() + assert params["stage2_unit1_conv2_weight"].dtype == np.float16 + + with mx.Context(mx.gpu(0)): + check_amp_convert_symbol() + check_amp_convert_model() + check_amp_convert_hybrid_block() + + +@with_seed() +def test_module_backward_compatibility(): + channel_num = 10 + conv_layer_filter_dims = [2, 3] + conv_layer_strides = [1, 1] + dimension = 5 + data_len = 10 + + data = mx.sym.var("data") + conv = mx.sym.Convolution(data, + num_filter=channel_num, + kernel=tuple(conv_layer_filter_dims), + stride=tuple(conv_layer_strides)) + + bn = mx.sym.BatchNorm(conv, + eps=0.001, + momentum=0.9, + fix_gamma=False, + use_global_stats=False, + output_mean_var=False, + name="conv0_batchnorm") + fc = mx.sym.FullyConnected(bn, num_hidden=10, name="fullyconnected") + mod = mx.mod.Module(fc, data_names=["data"], context=mx.gpu(0)) + mod.bind(data_shapes=[['data', (1, 3, 224, 224)]]) + mod.init_params() + + arg_params, aux_params = mod.get_params() + for param_key, param_val in arg_params.items(): + assert param_val.dtype == np.float32, "Incorrect inference type for arg_params," \ + "please check simple_bind for module executor" + for param_key, param_val in aux_params.items(): + assert param_val.dtype == np.float32, "Incorrect inference type for aux_params," \ + "please check simple_bind for module executor" + + + sym, arg_params, aux_params = amp.convert_model(mod._symbol, mod._arg_params, mod._aux_params, target_dtype_ops=["Convolution"]) + mod = mx.mod.Module(sym, data_names=["data"], context=mx.gpu(0)) + mod.bind(data_shapes=[['data', (1, 3, 224, 224)]]) + mod.set_params(arg_params, aux_params) + assert arg_params["fullyconnected_weight"].dtype == np.float16, \ + "Module API is overwriting the inferred dtype for a mixed precision model" + + +@with_seed() +def test_fp16_casting(): + data = mx.sym.var("data") + out1 = mx.sym.amp_cast(data, dtype="float16") + out2 = mx.sym.amp_cast(data, dtype="float32") + out3 = mx.sym.amp_cast(data, dtype="float16") + # When two ops from data, with different dtypes, + # data should be float32 + res = mx.sym.Group([out1, out2]) + final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True) + exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2)) + assert exe.arg_arrays[0].dtype == np.float32 + + # When two ops from data, both casted to float16, + # data should be float16 + res = mx.sym.Group([out1, out3]) + final_res = amp.convert_symbol(res, data_names=[], cast_optional_params=True) + exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2)) + assert exe.arg_arrays[0].dtype == np.float16 + + # AMP Multicast test where one node is float32, another is float16 + data = mx.sym.var("data", dtype=np.float32) + data2 = mx.sym.var("data2", dtype=np.float16) + out4 = mx.sym.amp_multicast(data, data2, num_outputs=2) + final_res = amp.convert_symbol(out4, cast_optional_params=True) + exe = final_res.simple_bind(ctx=mx.gpu(), data2=(1, 2), data=(1, 2)) + assert exe.arg_arrays[0].dtype == np.float16 + + # AMP Multicast test where two non input nodes are float16, + # and one input node is float32 + data = mx.sym.var("data", dtype=np.float32) + data2 = mx.sym.var("data2", dtype=np.float16) + data3 = mx.sym.var("data3", dtype=np.float16) + out5 = mx.sym.amp_multicast(data, + mx.sym.elemwise_add(data2, data3), + num_outputs=2) + final_res = amp.convert_symbol(out5, target_dtype_ops=[], + fp32_ops=[], cast_optional_params=True) + exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2), data3=(1, 2)) + assert exe.arg_arrays[0].dtype == np.float32 + + # AMP Multicast test where three input nodes one fp16, one fp32 + # one unknown + data = mx.sym.var("data", dtype=np.float16) + data2 = mx.sym.var("data2", dtype=np.float32) + data3 = mx.sym.var("data3") + out6 = mx.sym.amp_multicast(data, data2, data3, num_outputs=3) + final_res = amp.convert_symbol(out6, target_dtype_ops=[], + fp32_ops=[], cast_optional_params=True) + exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2), + data3=(1, 2)) + assert exe.arg_arrays[2].dtype == np.float32 + + # Input node to amp_multicast and amp_cast, if dtypes conflict + # and input node is already fp16, it should still be fp16 + data = mx.sym.var("data", dtype=np.float16) + data2 = mx.sym.var("data2", dtype=np.float32) + out7 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")]) + final_res = amp.convert_symbol(out7, target_dtype_ops=[], + fp32_ops=[], cast_optional_params=True) + exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2)) + assert exe.arg_arrays[0].dtype == np.float16 + + # Input node to amp_multicast and amp_cast, if dtypes conflict + # and input node is already fp32, it should be changed to fp16 + data = mx.sym.var("data", dtype=np.float32) + data2 = mx.sym.var("data2", dtype=np.float16) + out8 = mx.sym.Group([mx.sym.amp_multicast(data, data2, num_outputs=2), mx.sym.amp_cast(data, dtype="float16")]) + final_res = amp.convert_symbol(out8, target_dtype_ops=[], + fp32_ops=[], cast_optional_params=True) + exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2)) + assert exe.arg_arrays[0].dtype == np.float16 + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/tensorrt/test_tensorrt_lenet5.py b/tests/python/tensorrt/test_tensorrt_lenet5.py index ce88b9de3f5c..d105d6517887 100644 --- a/tests/python/tensorrt/test_tensorrt_lenet5.py +++ b/tests/python/tensorrt/test_tensorrt_lenet5.py @@ -46,7 +46,7 @@ def run_inference(sym, arg_params, aux_params, mnist, all_test_labels, batch_siz # Get this value from all_test_labels # Also get classes from the dataset num_ex = 10000 - all_preds = np.zeros([num_ex, 10]) + all_preds = np.zeros([num_ex, 10], dtype=np.float32) test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size) example_ct = 0 diff --git a/tests/python/unittest/test_contrib_amp.py b/tests/python/unittest/test_contrib_amp.py deleted file mode 100644 index 13048c35371e..000000000000 --- a/tests/python/unittest/test_contrib_amp.py +++ /dev/null @@ -1,85 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import mxnet as mx -import warnings -import collections -import ctypes -import mxnet.contrib.amp as amp - -def test_amp_coverage(): - conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS] - - # Check for duplicates - for a in [amp.lists.symbol.FP16_FUNCS, - amp.lists.symbol.FP16_FP32_FUNCS, - amp.lists.symbol.FP32_FUNCS, - amp.lists.symbol.WIDEST_TYPE_CASTS, - conditional]: - ret = [item for item, count in collections.Counter(a).items() if count > 1] - assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists." - - t = [] - for a in [amp.lists.symbol.FP16_FUNCS, - amp.lists.symbol.FP16_FP32_FUNCS, - amp.lists.symbol.FP32_FUNCS, - amp.lists.symbol.WIDEST_TYPE_CASTS, - conditional]: - t += a - ret = [item for item, count in collections.Counter(t).items() if count > 1] - assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list." - - # Check the coverage - py_str = lambda x: x.decode('utf-8') - - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - mx.base._LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist)) - op_names = [] - for i in range(size.value): - s = py_str(plist[i]) - if not s.startswith("_backward") \ - and not s.startswith("_contrib_backward_"): - op_names.append(s) - - ret1 = set(op_names) - set(t) - - if ret1 != set(): - warnings.warn("Operators " + str(ret1) + " do not exist in AMP lists (in " - "python/mxnet/contrib/amp/lists/symbol.py) - please add them. " - """Please follow these guidelines for choosing a proper list: - - if your operator is not to be used in a computational graph - (e.g. image manipulation operators, optimizers) or does not have - inputs, put it in FP16_FP32_FUNCS list, - - if your operator requires FP32 inputs or is not safe to use with lower - precision, put it in FP32_FUNCS list, - - if your operator supports both FP32 and lower precision, has - multiple inputs and expects all inputs to be of the same - type, put it in WIDEST_TYPE_CASTS list, - - if your operator supports both FP32 and lower precision and has - either a single input or supports inputs of different type, - put it in FP16_FP32_FUNCS list, - - if your operator is both safe to use in lower precision and - it is highly beneficial to use it in lower precision, then - put it in FP16_FUNCS (this is unlikely for new operators) - - If you are not sure which list to choose, FP32_FUNCS is the - safest option""") - -if __name__ == '__main__': - test_amp_coverage()