Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.5.x] FP16 Support for C Predict API (#15245) #16027

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 126 additions & 5 deletions amalgamation/python/mxnet_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,77 @@

import os
import sys
from array import array
import ctypes
import logging
import numpy as np

# pylint: disable= no-member
_DTYPE_NP_TO_MX = {
None: -1,
np.float32: 0,
np.float64: 1,
np.float16: 2,
np.uint8: 3,
np.int32: 4,
np.int8: 5,
np.int64: 6,
}

_DTYPE_MX_TO_NP = {
-1: None,
0: np.float32,
1: np.float64,
2: np.float16,
3: np.uint8,
4: np.int32,
5: np.int8,
6: np.int64,
}

__all__ = ["Predictor", "load_ndarray_file"]

if sys.version_info[0] == 3:
    def py_str(x):
        """Decode UTF-8 bytes into a Python 3 ``str``."""
        return x.decode('utf-8')

    def c_str_array(strings):
        """Build a ctypes ``const char **`` from Python strings.

        Parameters
        ----------
        strings : list of string
            Python strings; each one is UTF-8 encoded before being stored.

        Returns
        -------
        (ctypes.c_char_p * len(strings))
            An array of C strings that can be handed to the C API.
        """
        array_type = ctypes.c_char_p * len(strings)
        return array_type(*[item.encode('utf-8') for item in strings])


else:
    def py_str(x):
        """Return ``x`` unchanged (Python 2 strings are already bytes)."""
        return x

    def c_str_array(strings):
        """Build a ctypes ``const char **`` from Python strings.

        Parameters
        ----------
        strings : list of strings
            Python strings; stored as-is, without any encoding step.

        Returns
        -------
        (ctypes.c_char_p * len(strings))
            An array of C strings that can be handed to the C API.
        """
        array_type = ctypes.c_char_p * len(strings)
        return array_type(*strings)


def c_str(string):
""""Convert a python string to C string."""
Expand All @@ -48,6 +108,11 @@ def c_array(ctype, values):
"""Create ctypes array from a python array."""
return (ctype * len(values))(*values)

def c_array_buf(ctype, buf):
    """Expose a Python buffer as a ctypes array of ``ctype``.

    Parameters
    ----------
    ctype : ctypes data type
        Element type of the resulting array.
    buf : writable buffer object supporting ``len()``
        Source buffer, e.g. an ``array.array``. The result has ``len(buf)``
        elements and is created with ``from_buffer``, so it shares memory with
        ``buf`` instead of copying it.

    Returns
    -------
    (ctype * len(buf))
        A ctypes array backed by ``buf``.
    """
    array_type = ctype * len(buf)
    return array_type.from_buffer(buf)



def _find_lib_path():
"""Find mxnet library."""
Expand Down Expand Up @@ -87,9 +152,18 @@ def _check_call(ret):
if ret != 0:
raise RuntimeError(py_str(_LIB.MXGetLastError()))


def _monitor_callback_wrapper(callback):
"""A wrapper for the user-defined handle."""
def callback_handle(name, array, _):
""" ctypes function """
callback(name, array)
return callback_handle

_LIB = _load_lib()
# type definitions
mx_uint = ctypes.c_uint
mx_int = ctypes.c_int
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)
PredictorHandle = ctypes.c_void_p
Expand All @@ -116,10 +190,13 @@ class Predictor(object):

dev_id : int, optional
The device id of the predictor.

type_dict : Dict of str->numpy.dtype
Input type dictionary, name->dtype
"""
def __init__(self, symbol_file,
param_raw_bytes, input_shapes,
dev_type="cpu", dev_id=0):
dev_type="cpu", dev_id=0, type_dict=None):
dev_type = devstr2type[dev_type]
indptr = [0]
sdata = []
Expand All @@ -133,15 +210,38 @@ def __init__(self, symbol_file,
handle = PredictorHandle()
param_raw_bytes = bytearray(param_raw_bytes)
ptr = (ctypes.c_char * len(param_raw_bytes)).from_buffer(param_raw_bytes)
_check_call(_LIB.MXPredCreate(

# data types
num_provided_arg_types = 0
# provided type argument names
provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)()
# provided types
provided_arg_type_data = ctypes.POINTER(mx_uint)()
if type_dict is not None:
provided_arg_type_names = []
provided_arg_type_data = []
for k, v in type_dict.items():
v = np.dtype(v).type
if v in _DTYPE_NP_TO_MX:
provided_arg_type_names.append(k)
provided_arg_type_data.append(_DTYPE_NP_TO_MX[v])
num_provided_arg_types = mx_uint(len(provided_arg_type_names))
provided_arg_type_names = c_str_array(provided_arg_type_names)
provided_arg_type_data = c_array_buf(ctypes.c_int, array('i', provided_arg_type_data))

_check_call(_LIB.MXPredCreateEx(
c_str(symbol_file),
ptr, len(param_raw_bytes),
ctypes.c_int(dev_type), ctypes.c_int(dev_id),
mx_uint(len(indptr) - 1),
c_array(ctypes.c_char_p, keys),
c_array(mx_uint, indptr),
c_array(mx_uint, sdata),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
ctypes.byref(handle)))
self.type_dict = type_dict
self.handle = handle

def __del__(self):
Expand All @@ -160,10 +260,18 @@ def forward(self, **kwargs):
>>> predictor.forward(data=mydata)
>>> out = predictor.get_output(0)
"""
if self.type_dict and len(self.type_dict) != len(kwargs.items()):
raise ValueError("number of kwargs should be same as len of type_dict" \
"Please check your forward pass inputs" \
"or type_dict passed to Predictor instantiation")

for k, v in kwargs.items():
if not isinstance(v, np.ndarray):
raise ValueError("Expect numpy ndarray as input")
v = np.asarray(v, dtype=np.float32, order='C')
if self.type_dict and k in self.type_dict:
v = np.asarray(v, dtype=self.type_dict[k], order='C')
else:
v = np.asarray(v, dtype=np.float32, order='C')
_check_call(_LIB.MXPredSetInput(
self.handle, c_str(k),
v.ctypes.data_as(mx_float_p),
Expand Down Expand Up @@ -218,18 +326,30 @@ def get_output(self, index):
"""
pdata = ctypes.POINTER(mx_uint)()
ndim = mx_uint()
out_type = mx_int()
_check_call(_LIB.MXPredGetOutputShape(
self.handle, index,
ctypes.byref(pdata),
ctypes.byref(ndim)))
_check_call(_LIB.MXPredGetOutputType(
self.handle, index,
ctypes.byref(out_type)))
shape = tuple(pdata[:ndim.value])
data = np.empty(shape, dtype=np.float32)
data = np.empty(shape, dtype=_DTYPE_MX_TO_NP[out_type.value])
_check_call(_LIB.MXPredGetOutput(
self.handle, mx_uint(index),
data.ctypes.data_as(mx_float_p),
mx_uint(data.size)))
return data

def set_monitor_callback(self, callback, monitor_all=False):
    """Install ``callback`` on this predictor via ``MXPredSetMonitorCallback``.

    Parameters
    ----------
    callback : callable
        Called as ``callback(name, array)``; it is wrapped with
        ``_monitor_callback_wrapper`` before being converted to a C function.
    monitor_all : bool, optional
        Passed to the library as a C int.
        NOTE(review): its exact effect is decided by the C implementation,
        which is not visible here -- confirm.
    """
    prototype = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p)
    # Stored on the instance: ctypes callback objects must stay referenced for
    # as long as C code may still invoke them.
    self._monitor_callback = prototype(_monitor_callback_wrapper(callback))
    status = _LIB.MXPredSetMonitorCallback(
        self.handle,
        self._monitor_callback,
        None,
        ctypes.c_int(monitor_all),
    )
    _check_call(status)


def load_ndarray_file(nd_bytes):
"""Load ndarray file and return as list of numpy array.
Expand Down Expand Up @@ -273,4 +393,5 @@ def load_ndarray_file(nd_bytes):
if len(keys) == 0 or len(keys[0]) == 0:
return arrs
else:
return {keys[i] : arrs[i] for i in range(len(keys))}
return {keys[i] : arrs[i] for i in range(len(keys))
}
65 changes: 65 additions & 0 deletions include/mxnet/c_predict_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ typedef float mx_float;
typedef void *PredictorHandle;
/*! \brief handle to NDArray list */
typedef void *NDListHandle;
/*! \brief handle to NDArray */
typedef void *NDArrayHandle;
/*!
 * \brief Callback type used to add monitoring to nodes in the graph.
 *  The callback receives, in order, a C string, an NDArrayHandle and an
 *  opaque void* pointer, and returns nothing.
 *  NOTE(review): the string is presumably the name of the monitored node and
 *  the void* the callback_handle given to MXPredSetMonitorCallback -- confirm
 *  against the implementation.
 */
typedef void (*PredMonitorCallback)(const char*,
NDArrayHandle,
void*);

/*!
* \brief Get the last error that happened.
Expand Down Expand Up @@ -85,6 +91,44 @@ MXNET_DLL int MXPredCreate(const char* symbol_json_str,
const mx_uint* input_shape_data,
PredictorHandle* out);

/*!
 * \brief Create a predictor, optionally providing dtypes for named arguments.
 * \param symbol_json_str The JSON string of the symbol.
 * \param param_bytes The in-memory raw bytes of parameter ndarray file.
 * \param param_size The size of parameter ndarray file.
 * \param dev_type The device type, 1: cpu, 2: gpu
 * \param dev_id The device id of the predictor.
 * \param num_input_nodes Number of input nodes to the net.
 *    For feedforward net, this is 1.
 * \param input_keys The name of the input argument.
 *    For feedforward net, this is {"data"}
 * \param input_shape_indptr Index pointer of shapes of each input node.
 *    The length of this array = num_input_nodes + 1.
 *    For feedforward net that takes 4 dimensional input, this is {0, 4}.
 * \param input_shape_data A flattened data of shapes of each input node.
 *    For feedforward net that takes 4 dimensional input, this is the shape data.
 * \param num_provided_arg_dtypes
 *    The number of entries in provided_arg_dtypes.
 * \param provided_arg_dtype_names
 *    The names of the arguments for which dtypes are provided.
 * \param provided_arg_dtypes
 *    The provided dtypes, given as integer dtype codes.
 *    NOTE(review): presumably entry i applies to provided_arg_dtype_names[i]
 *    -- confirm against the implementation.
 * \param out The created predictor handle.
 * \return 0 when success, -1 when failure.
 */
MXNET_DLL int MXPredCreateEx(const char* symbol_json_str,
const void* param_bytes,
int param_size,
int dev_type, int dev_id,
const mx_uint num_input_nodes,
const char** input_keys,
const mx_uint* input_shape_indptr,
const mx_uint* input_shape_data,
const mx_uint num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
PredictorHandle* out);

/*!
* \brief create a predictor with customized outputs
* \param symbol_json_str The JSON string of the symbol.
Expand Down Expand Up @@ -186,6 +230,18 @@ MXNET_DLL int MXPredGetOutputShape(PredictorHandle handle,
mx_uint index,
mx_uint** shape_data,
mx_uint* shape_ndim);

/*!
 * \brief Get the dtype of output node.
 *  The returned data type is only valid before next call to MXPred function.
 * \param handle The handle of the predictor.
 * \param out_index The index of the output node, set to 0 if there is only one output.
 * \param out_dtype Pointer to an int that receives the dtype of the output node
 *    as an integer dtype code.
 * \return 0 when success.
 *    NOTE(review): the failure value is presumably -1, as documented for
 *    MXPredCreateEx -- confirm.
 */
MXNET_DLL int MXPredGetOutputType(PredictorHandle handle,
mx_uint out_index,
int* out_dtype);

/*!
* \brief Set the input data of predictor.
* \param handle The predictor handle.
Expand Down Expand Up @@ -269,6 +325,15 @@ MXNET_DLL int MXNDListGet(NDListHandle handle,
const mx_float** out_data,
const mx_uint** out_shape,
mx_uint* out_ndim);

/*!
 * \brief Set a callback to notify the completion of operation and allow for
 *  additional monitoring.
 * \param handle The handle of the predictor.
 * \param callback The monitor callback, of type PredMonitorCallback.
 * \param callback_handle Opaque pointer supplied alongside the callback.
 *    NOTE(review): presumably passed back as the callback's void* argument
 *    -- confirm against the implementation.
 * \param monitor_all Boolean monitoring flag.
 *    NOTE(review): its exact effect is not visible from this header -- confirm.
 *    Also confirm that `bool` is defined when this header is compiled as C.
 * \return 0 when success.
 *    NOTE(review): the failure value is presumably -1, as documented for
 *    MXPredCreateEx -- confirm.
 */
MXNET_DLL int MXPredSetMonitorCallback(PredictorHandle handle,
PredMonitorCallback callback,
void* callback_handle,
bool monitor_all);
/*!
* \brief Free a MXAPINDList
* \param handle The handle of the MXAPINDList.
Expand Down
Loading