diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 4df97e594a01..24be081344aa 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -72,21 +72,31 @@ np.int8: 5, np.int64: 6, np.bool_: 7, + np.int16: 8, + np.uint16 : 9, + np.uint32 : 10, + np.uint64 : 11, np.dtype([('bfloat16', np.uint16)]): 12, } -_DTYPE_MX_TO_NP = { - -1: None, - 0: np.float32, - 1: np.float64, - 2: np.float16, - 3: np.uint8, - 4: np.int32, - 5: np.int8, - 6: np.int64, - 7: np.bool_, - 12: np.dtype([('bfloat16', np.uint16)]), -} +def _register_platform_dependent_mx_dtype(): + """Register platform dependent types to the fixed size counterparts.""" + kind_map = {'i': 'int', 'u': 'uint', 'f': 'float'} + for np_type in [ + np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc, np.int_, + np.uint, np.longlong, np.ulonglong, np.half, np.float16, np.single, + np.double, np.longdouble]: + dtype = np.dtype(np_type) + kind, size = dtype.kind, dtype.itemsize + bits = size * 8 + fixed_size_dtype = np + fixed_dtype = getattr(np, kind_map[kind]+str(bits)) + if fixed_dtype in _DTYPE_NP_TO_MX: + _DTYPE_NP_TO_MX[np_type] = _DTYPE_NP_TO_MX[fixed_dtype] +_register_platform_dependent_mx_dtype() + + +_DTYPE_MX_TO_NP = {v: k for k, v in _DTYPE_NP_TO_MX.items()} _STORAGE_TYPE_STR_TO_ID = { 'undefined': _STORAGE_TYPE_UNDEFINED, diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 17c549193324..5274408e4403 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -185,7 +185,8 @@ def _as_mx_np_array(object, ctx=None): if isinstance(object, ndarray): return object elif isinstance(object, _np.ndarray): - return array(object, dtype=object.dtype, ctx=ctx) + np_dtype = _np.dtype(object.dtype).type + return array(object, dtype=np_dtype, ctx=ctx) elif isinstance(object, (integer_types, numeric_types)): return object elif isinstance(object, (list, tuple)):