diff --git a/amalgamation/python/mxnet_predict.py b/amalgamation/python/mxnet_predict.py index a91d3849b0d2..48e3cd4a5145 100644 --- a/amalgamation/python/mxnet_predict.py +++ b/amalgamation/python/mxnet_predict.py @@ -25,17 +25,77 @@ import os import sys +from array import array import ctypes import logging import numpy as np +# pylint: disable= no-member +_DTYPE_NP_TO_MX = { + None: -1, + np.float32: 0, + np.float64: 1, + np.float16: 2, + np.uint8: 3, + np.int32: 4, + np.int8: 5, + np.int64: 6, +} + +_DTYPE_MX_TO_NP = { + -1: None, + 0: np.float32, + 1: np.float64, + 2: np.float16, + 3: np.uint8, + 4: np.int32, + 5: np.int8, + 6: np.int64, +} + __all__ = ["Predictor", "load_ndarray_file"] if sys.version_info[0] == 3: py_str = lambda x: x.decode('utf-8') + + def c_str_array(strings): + """Create ctypes const char ** from a list of Python strings. + + Parameters + ---------- + strings : list of string + Python strings. + + Returns + ------- + (ctypes.c_char_p * len(strings)) + A const char ** pointer that can be passed to C API. + """ + arr = (ctypes.c_char_p * len(strings))() + arr[:] = [s.encode('utf-8') for s in strings] + return arr + + else: py_str = lambda x: x + def c_str_array(strings): + """Create ctypes const char ** from a list of Python strings. + + Parameters + ---------- + strings : list of strings + Python strings. + + Returns + ------- + (ctypes.c_char_p * len(strings)) + A const char ** pointer that can be passed to C API. + """ + arr = (ctypes.c_char_p * len(strings))() + arr[:] = strings + return arr + def c_str(string): """"Convert a python string to C string.""" @@ -48,6 +108,11 @@ def c_array(ctype, values): """Create ctypes array from a python array.""" return (ctype * len(values))(*values) +def c_array_buf(ctype, buf): + """Create ctypes array from a Python buffer.""" + return (ctype * len(buf)).from_buffer(buf) + + def _find_lib_path(): """Find mxnet library.""" @@ -87,9 +152,18 @@ def _check_call(ret): if ret != 0: raise RuntimeError(py_str(_LIB.MXGetLastError())) + +def _monitor_callback_wrapper(callback): + """A wrapper for the user-defined handle.""" + def callback_handle(name, array, _): + """ ctypes function """ + callback(name, array) + return callback_handle + _LIB = _load_lib() # type definitions mx_uint = ctypes.c_uint +mx_int = ctypes.c_int mx_float = ctypes.c_float mx_float_p = ctypes.POINTER(mx_float) PredictorHandle = ctypes.c_void_p @@ -116,10 +190,13 @@ class Predictor(object): dev_id : int, optional The device id of the predictor. + + type_dict : Dict of str->numpy.dtype + Input type dictionary, name->dtype """ def __init__(self, symbol_file, param_raw_bytes, input_shapes, - dev_type="cpu", dev_id=0): + dev_type="cpu", dev_id=0, type_dict=None): dev_type = devstr2type[dev_type] indptr = [0] sdata = [] @@ -133,7 +210,26 @@ def __init__(self, symbol_file, handle = PredictorHandle() param_raw_bytes = bytearray(param_raw_bytes) ptr = (ctypes.c_char * len(param_raw_bytes)).from_buffer(param_raw_bytes) - _check_call(_LIB.MXPredCreate( + + # data types + num_provided_arg_types = 0 + # provided type argument names + provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() + # provided types + provided_arg_type_data = ctypes.POINTER(mx_uint)() + if type_dict is not None: + provided_arg_type_names = [] + provided_arg_type_data = [] + for k, v in type_dict.items(): + v = np.dtype(v).type + if v in _DTYPE_NP_TO_MX: + provided_arg_type_names.append(k) + provided_arg_type_data.append(_DTYPE_NP_TO_MX[v]) + num_provided_arg_types = mx_uint(len(provided_arg_type_names)) + provided_arg_type_names = c_str_array(provided_arg_type_names) + provided_arg_type_data = c_array_buf(ctypes.c_int, array('i', provided_arg_type_data)) + + _check_call(_LIB.MXPredCreateEx( c_str(symbol_file), ptr, len(param_raw_bytes), ctypes.c_int(dev_type), ctypes.c_int(dev_id), @@ -141,7 +237,11 @@ def __init__(self, symbol_file, c_array(ctypes.c_char_p, keys), c_array(mx_uint, indptr), c_array(mx_uint, sdata), + num_provided_arg_types, + provided_arg_type_names, + provided_arg_type_data, ctypes.byref(handle))) + self.type_dict = type_dict self.handle = handle def __del__(self): @@ -160,10 +260,18 @@ def forward(self, **kwargs): >>> predictor.forward(data=mydata) >>> out = predictor.get_output(0) """ + if self.type_dict and len(self.type_dict) != len(kwargs.items()): + raise ValueError("number of kwargs should be same as len of type_dict" \ + "Please check your forward pass inputs" \ + "or type_dict passed to Predictor instantiation") + for k, v in kwargs.items(): if not isinstance(v, np.ndarray): raise ValueError("Expect numpy ndarray as input") - v = np.asarray(v, dtype=np.float32, order='C') + if self.type_dict and k in self.type_dict: + v = np.asarray(v, dtype=self.type_dict[k], order='C') + else: + v = np.asarray(v, dtype=np.float32, order='C') _check_call(_LIB.MXPredSetInput( self.handle, c_str(k), v.ctypes.data_as(mx_float_p), @@ -218,18 +326,30 @@ def get_output(self, index): """ pdata = ctypes.POINTER(mx_uint)() ndim = mx_uint() + out_type = mx_int() _check_call(_LIB.MXPredGetOutputShape( self.handle, index, ctypes.byref(pdata), ctypes.byref(ndim))) + _check_call(_LIB.MXPredGetOutputType( + self.handle, index, + ctypes.byref(out_type))) shape = tuple(pdata[:ndim.value]) - data = np.empty(shape, dtype=np.float32) + data = np.empty(shape, dtype=_DTYPE_MX_TO_NP[out_type.value]) _check_call(_LIB.MXPredGetOutput( self.handle, mx_uint(index), data.ctypes.data_as(mx_float_p), mx_uint(data.size))) return data + def set_monitor_callback(self, callback, monitor_all=False): + cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p) + self._monitor_callback = cb_type(_monitor_callback_wrapper(callback)) + _check_call(_LIB.MXPredSetMonitorCallback(self.handle, + self._monitor_callback, + None, + ctypes.c_int(monitor_all))) + def load_ndarray_file(nd_bytes): """Load ndarray file and return as list of numpy array. @@ -273,4 +393,5 @@ def load_ndarray_file(nd_bytes): if len(keys) == 0 or len(keys[0]) == 0: return arrs else: - return {keys[i] : arrs[i] for i in range(len(keys))} + return {keys[i] : arrs[i] for i in range(len(keys)) + } diff --git a/include/mxnet/c_predict_api.h b/include/mxnet/c_predict_api.h index ecbbf8dfc819..18bec625f05f 100644 --- a/include/mxnet/c_predict_api.h +++ b/include/mxnet/c_predict_api.h @@ -49,6 +49,12 @@ typedef float mx_float; typedef void *PredictorHandle; /*! \brief handle to NDArray list */ typedef void *NDListHandle; +/*! \brief handle to NDArray */ +typedef void *NDArrayHandle; +/*! \brief callback used for add monitoring to nodes in the graph */ +typedef void (*PredMonitorCallback)(const char*, + NDArrayHandle, + void*); /*! * \brief Get the last error happeneed. @@ -85,6 +91,44 @@ MXNET_DLL int MXPredCreate(const char* symbol_json_str, const mx_uint* input_shape_data, PredictorHandle* out); +/*! + * \brief create a predictor + * \param symbol_json_str The JSON string of the symbol. + * \param param_bytes The in-memory raw bytes of parameter ndarray file. + * \param param_size The size of parameter ndarray file. + * \param dev_type The device type, 1: cpu, 2: gpu + * \param dev_id The device id of the predictor. + * \param num_input_nodes Number of input nodes to the net. + * For feedforward net, this is 1. + * \param input_keys The name of the input argument. + * For feedforward net, this is {"data"} + * \param input_shape_indptr Index pointer of shapes of each input node. + * The length of this array = num_input_nodes + 1. + * For feedforward net that takes 4 dimensional input, this is {0, 4}. + * \param input_shape_data A flattened data of shapes of each input node. + * For feedforward net that takes 4 dimensional input, this is the shape data. + * \param num_provided_arg_dtypes + * The length of provided_arg_dtypes. + * \param provided_arg_dtype_names + * The provided_arg_dtype_names the names of args for which dtypes are provided. + * \param provided_arg_dtypes + * The provided_arg_dtypes the dtype provided + * \param out The created predictor handle. + * \return 0 when success, -1 when failure. + */ +MXNET_DLL int MXPredCreateEx(const char* symbol_json_str, + const void* param_bytes, + int param_size, + int dev_type, int dev_id, + const mx_uint num_input_nodes, + const char** input_keys, + const mx_uint* input_shape_indptr, + const mx_uint* input_shape_data, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + PredictorHandle* out); + /*! * \brief create a predictor wich customized outputs * \param symbol_json_str The JSON string of the symbol. @@ -186,6 +230,18 @@ MXNET_DLL int MXPredGetOutputShape(PredictorHandle handle, mx_uint index, mx_uint** shape_data, mx_uint* shape_ndim); + +/*! + * \brief Get the dtype of output node. + * The returned data type is only valid before next call to MXPred function. + * \param handle The handle of the predictor. + * \param out_index The index of the output node, set to 0 if there is only one output. + * \param out_dtype The dtype of the output node + */ +MXNET_DLL int MXPredGetOutputType(PredictorHandle handle, + mx_uint out_index, + int* out_dtype); + /*! * \brief Set the input data of predictor. * \param handle The predictor handle. @@ -269,6 +325,15 @@ MXNET_DLL int MXNDListGet(NDListHandle handle, const mx_float** out_data, const mx_uint** out_shape, mx_uint* out_ndim); + +/*! + * \brief set a call back to notify the completion of operation and allow for + * additional monitoring + */ +MXNET_DLL int MXPredSetMonitorCallback(PredictorHandle handle, + PredMonitorCallback callback, + void* callback_handle, + bool monitor_all); /*! * \brief Free a MXAPINDList * \param handle The handle of the MXAPINDList. diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc index 7de23ef935ef..b371fd044dc5 100644 --- a/src/c_api/c_predict_api.cc +++ b/src/c_api/c_predict_api.cc @@ -47,6 +47,9 @@ struct MXAPIPredictor { std::vector aux_arrays; // output shapes mxnet::ShapeVector out_shapes; + // output types + nnvm::DTypeVector out_dtypes; + // uint32_t buffer for output shapes std::vector out_shapes_buffer; // key to arguments @@ -88,7 +91,7 @@ int _CreatePartialOut(const char* symbol_json_str, const void* param_bytes, int param_size, int dev_type, int dev_id, - mx_uint num_input_nodes, + const mx_uint num_input_nodes, const char** input_keys, const mx_uint* input_shape_indptr, const mx_uint* input_shape_data, @@ -97,6 +100,9 @@ int _CreatePartialOut(const char* symbol_json_str, // This is used for parallel inference. int num_threads, bool lazy, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, PredictorHandle* out) { using nnvm::Symbol; @@ -135,6 +141,7 @@ int _CreatePartialOut(const char* symbol_json_str, // load the parameters std::unordered_map arg_params, aux_params; + std::unordered_map arg_types, aux_types; { std::unordered_set arg_names, aux_names; std::vector arg_names_vec = sym.ListInputNames(Symbol::kReadOnlyArgs); @@ -156,12 +163,23 @@ int _CreatePartialOut(const char* symbol_json_str, std::string name(names[i].c_str() + 4); if (aux_names.count(name) != 0) { aux_params[name] = data[i]; + aux_types[name] = data[i].dtype(); } } if (!strncmp(names[i].c_str(), "arg:", 4)) { std::string name(names[i].c_str() + 4); if (arg_names.count(name) != 0) { arg_params[name] = data[i]; + arg_types[name] = data[i].dtype(); + } + } + } + + if (num_provided_arg_dtypes > 0) { + for (mx_uint i = 0; i < num_provided_arg_dtypes; ++i) { + if (aux_types.count(provided_arg_dtype_names[i]) == 0 && + arg_types.count(provided_arg_dtype_names[i]) == 0) { + arg_types[provided_arg_dtype_names[i]] = provided_arg_dtypes[i]; } } } @@ -179,6 +197,7 @@ int _CreatePartialOut(const char* symbol_json_str, mxnet::ShapeVector out_shapes(sym.ListOutputNames().size()); mxnet::ShapeVector aux_shapes(aux_names.size()); mxnet::ShapeVector arg_shapes; + nnvm::DTypeVector result_arg_types, result_out_types, result_aux_types; std::unordered_map key2arg; for (size_t i = 0; i < arg_names.size(); ++i) { std::string key = arg_names[i]; @@ -187,6 +206,7 @@ int _CreatePartialOut(const char* symbol_json_str, try { mxnet::ShapeVector in_shapes; + nnvm::DTypeVector in_types; for (std::string key : sym.ListInputNames(Symbol::kAll)) { if (known_shape.count(key) != 0) { in_shapes.push_back(known_shape[key]); @@ -194,14 +214,38 @@ int _CreatePartialOut(const char* symbol_json_str, in_shapes.emplace_back(); } } + + for (std::string key : sym.ListInputNames(Symbol::kAll)) { + if (arg_types.count(key) != 0) { + in_types.push_back(arg_types[key]); + } else if (aux_types.count(key) != 0) { + in_types.push_back(aux_types[key]); + } else { + // if key not in arg_types or aux_types set to FP32 + in_types.push_back(0); + } + } nnvm::Graph g; g.outputs = sym.outputs; g = mxnet::exec::InferShape(std::move(g), std::move(in_shapes), "__shape__"); + g = mxnet::exec::InferType(std::move(g), std::move(in_types), "__dtype__"); bool infer_complete = (g.GetAttr("shape_num_unknown_nodes") == 0); + // This is tricky for AMP Use case, for example, with only weights input types + // cannot be inferred in AMP. Thus for AMP converted model type_dict will be + // required + bool infer_type_complete = (g.GetAttr("dtype_num_unknown_nodes") == 0); CHECK(infer_complete) << "The shape information of is not enough to get the shapes"; + CHECK(infer_type_complete) + << "The type information is not enough, please provide input arg_types " + "with provided_arg_dtype_names and provided_arg_dtypes." + "If using amalgamation python frontend you can use type_dict in Predictor API" + "to provide this information"; CopyAttr(g.indexed_graph(), g.GetAttr("shape"), &arg_shapes, &out_shapes, &aux_shapes); + CopyAttr(g.indexed_graph(), + g.GetAttr("dtype"), + &result_arg_types, &result_out_types, &result_aux_types); } catch (const mxnet::op::InferShapeError &err) { throw dmlc::Error(err.msg); } @@ -210,19 +254,31 @@ int _CreatePartialOut(const char* symbol_json_str, std::vector arg_arrays, aux_arrays; for (size_t i = 0; i < arg_shapes.size(); ++i) { - NDArray nd = NDArray(arg_shapes[i], ctx); + NDArray nd; + if (result_arg_types[i] != -1) { + nd = NDArray(arg_shapes[i], ctx, false, result_arg_types[i]); + } else { + nd = NDArray(arg_shapes[i], ctx); + } if (arg_params.count(arg_names[i]) != 0) { CopyFromTo(arg_params[arg_names[i]], &nd); } arg_arrays.push_back(nd); } + for (size_t i = 0; i < aux_shapes.size(); ++i) { - NDArray nd = NDArray(aux_shapes[i], ctx); + NDArray nd; + if (result_aux_types[i] != -1) { + nd = NDArray(aux_shapes[i], ctx, false, result_aux_types[i]); + } else { + nd = NDArray(aux_shapes[i], ctx); + } if (aux_params.count(aux_names[i]) != 0) { CopyFromTo(aux_params[aux_names[i]], &nd); } aux_arrays.push_back(nd); } + // bind for (int i = 0; i < num_threads; i++) { std::unique_ptr ret(new MXAPIPredictor()); @@ -232,6 +288,7 @@ int _CreatePartialOut(const char* symbol_json_str, ret->arg_arrays = arg_arrays; ret->aux_arrays = aux_arrays; ret->out_shapes = out_shapes; + ret->out_dtypes = result_out_types; if (!lazy) { std::map ctx_map; @@ -272,6 +329,9 @@ int MXPredCreatePartialOut(const char* symbol_json_str, output_keys, 1, false, + 0, + nullptr, + nullptr, out); } @@ -295,9 +355,44 @@ int MXPredCreate(const char* symbol_json_str, input_shape_indptr, input_shape_data, 0, - NULL, + nullptr, + 1, + false, + 0, + nullptr, + nullptr, + out); +} + +int MXPredCreateEx(const char* symbol_json_str, + const void* param_bytes, + int param_size, + int dev_type, int dev_id, + mx_uint num_input_nodes, + const char** input_keys, + const mx_uint* input_shape_indptr, + const mx_uint* input_shape_data, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + PredictorHandle* out) { + return _CreatePartialOut( + symbol_json_str, + param_bytes, + param_size, + dev_type, + dev_id, + num_input_nodes, + input_keys, + input_shape_indptr, + input_shape_data, + 0, + nullptr, 1, false, + num_provided_arg_dtypes, + provided_arg_dtype_names, + provided_arg_dtypes, out); } @@ -330,9 +425,12 @@ int MXPredCreateMultiThread(const char* symbol_json_str, input_shape_indptr, input_shape_data, 0, - NULL, + nullptr, num_threads, true, + 0, + nullptr, + nullptr, out); } @@ -421,6 +519,7 @@ int MXPredReshape(mx_uint num_input_nodes, p->exec.get())); ret->out_shapes = out_shapes; ret->out_arrays = ret->exec->outputs(); + ret->out_dtypes = p->out_dtypes; } *out = ret.release(); API_END(); @@ -444,6 +543,21 @@ int MXPredGetOutputShape(PredictorHandle handle, API_END(); } +int MXPredGetOutputType(PredictorHandle handle, + mx_uint out_index, + int* out_dtype) { + MXAPIPredictor* p = static_cast(handle); + API_BEGIN(); + CHECK_LT(out_index, p->out_arrays.size()) + << "Index exceed number of outputs, provided out_index should be less than " + << p->out_arrays.size(); + + const int s = p->out_dtypes[out_index]; + CHECK_GE(s, 0); + out_dtype[out_index] = s; + API_END(); +} + int MXPredSetInput(PredictorHandle handle, const char* key, const mx_float* data, @@ -543,6 +657,22 @@ int MXNDListGet(NDListHandle handle, API_END(); } +int MXPredSetMonitorCallback(PredictorHandle handle, + PredMonitorCallback callback, + void* callback_handle, + bool monitor_all) { + MXAPIPredictor* p = static_cast(handle); + API_BEGIN(); + PredMonitorCallback callback_temp = callback; + void* callback_handle_temp = callback_handle; + std::function clbk + = [callback_temp, callback_handle_temp](const char* name, void* handle) { + callback_temp(name, handle, callback_handle_temp); + }; + p->exec->SetMonitorCallback(clbk, monitor_all); + API_END(); +} + int MXNDListFree(NDListHandle handle) { API_BEGIN(); delete static_cast(handle); diff --git a/tests/python/gpu/test_predictor.py b/tests/python/gpu/test_predictor.py new file mode 100644 index 000000000000..4838a76c7cb1 --- /dev/null +++ b/tests/python/gpu/test_predictor.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import print_function +import sys, os +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, "../../../amalgamation/python/")) +from mxnet_predict import Predictor, load_ndarray_file + +import ctypes +import numpy as np +import mxnet as mx +import mxnet.ndarray as nd +from mxnet.ndarray import NDArray +from mxnet import gluon +from mxnet.test_utils import assert_almost_equal, download_model +from mxnet.contrib.amp import amp +from mxnet.base import NDArrayHandle, py_str +sys.path.insert(0, os.path.join(curr_path, '../unittest')) +from common import setup_module, with_seed, teardown + +@with_seed() +def test_predictor_with_dtype(): + prefix = 'test_predictor_simple_dense' + symbol_file = "%s-symbol.json" % prefix + param_file = "%s-0000.params" % prefix + + input1 = np.random.uniform(size=(1, 3)) + input1 = input1.astype(np.float16) + + block = mx.gluon.nn.HybridSequential() + block.add(mx.gluon.nn.Dense(7)) + block.add(mx.gluon.nn.Dense(3)) + block.cast(np.float16) + block.hybridize() + block.initialize(ctx=mx.gpu(0)) + tmp = mx.nd.array(input1, dtype=np.float16, ctx=mx.gpu(0)) + out1 = block.forward(tmp) + block.export(prefix) + + predictor = Predictor(open(symbol_file, "r").read(), + open(param_file, "rb").read(), + {"data": input1.shape}, + dev_type="gpu", + dev_id=0, + type_dict={"data": input1.dtype}) + predictor.forward(data=input1) + predictor_out1 = predictor.get_output(0) + + assert_almost_equal(out1.asnumpy(), predictor_out1, rtol=1e-5, atol=1e-6) + +def compare_module_cpredict(result_sym, result_arg_params, result_aux_params, monitor_callback=False): + # Dummmy inputs + input1 = np.ones((1, 3, 224, 224)) + input1 = input1.astype(np.float32) + nd_dict = {} + def pred_mon_callback(name, arr): + nd_dict[name] = arr + mod = mx.mod.Module(result_sym, data_names=["data"], label_names=["softmax_label"], context=mx.gpu()) + mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]], for_training=False) + mod.set_params(result_arg_params, result_aux_params) + mod.forward(mx.io.DataBatch(data=[mx.nd.array(input1, ctx=mx.gpu())], + label=[mx.nd.ones((1,), ctx=mx.gpu())])) + prefix = "test_predictor_amp" + mod.save_checkpoint(prefix, 0, remove_amp_cast=False) + sym_file = "{}-symbol.json".format(prefix) + params_file = "{}-0000.params".format(prefix) + predictor = Predictor(open(sym_file, "r").read(), + open(params_file, "rb").read(), + {'data': (1, 3, 224, 224), + 'softmax_label': (1,)}, + dev_type="gpu", + dev_id=0) + if monitor_callback: + predictor.set_monitor_callback(pred_mon_callback, monitor_all=True) + predictor.forward(data=input1, softmax_label=mx.nd.ones((1,)).asnumpy()) + predictor_out1 = predictor.get_output(0) + if monitor_callback: + assert len(nd_dict) > 0, "Callback not called" + assert_almost_equal(mod.get_outputs()[0].asnumpy(), predictor_out1, atol=1e-1, rtol=1e-1) + + +@with_seed() +def test_predictor_amp(): + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if not os.path.isdir(model_path): + os.mkdir(model_path) + prefix, epoch = download_model("imagenet1k-resnet-18", dst_dir=model_path) + + sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) + + + # Convert model to mixed precision model, params in FP32 + result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, + arg_params, + aux_params, + target_dtype="float16", + target_dtype_ops=["Convolution"]) + compare_module_cpredict(result_sym, result_arg_params, result_aux_params) + + # Convert model to mixed precision model, params in FP16 + result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, + arg_params, + aux_params, + target_dtype="float16", + target_dtype_ops=["Convolution"], + cast_optional_params=True) + compare_module_cpredict(result_sym, result_arg_params, result_aux_params, monitor_callback=True) + + +if __name__ == '__main__': + import nose + nose.runmodule()