
Commit

address comments
haojin2 committed Aug 7, 2019
1 parent 275b063 commit 345c522
Showing 14 changed files with 86 additions and 88 deletions.
4 changes: 2 additions & 2 deletions include/mxnet/base.h
@@ -451,9 +451,9 @@ inline int32_t Context::GetGPUCount() {
 }
   int32_t count;
   cudaError_t e = cudaGetDeviceCount(&count);
-  // TODO(junwu): Remove e == 35
+  // TODO(junwu): Remove e == cudaErrorInsufficientDriver
   // This is skipped for working around wheel build system with older CUDA driver.
-  if (e == cudaErrorNoDevice || e == 35) {
+  if (e == cudaErrorNoDevice || e == cudaErrorInsufficientDriver) {
     return 0;
   }
   CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
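
A user-visible consequence of this hunk is that GPU enumeration degrades gracefully on machines whose CUDA driver is missing or too old. A minimal sketch of how this is typically consumed from Python, assuming the count is surfaced through mxnet.context.num_gpus():

    # Sketch: pick a device without crashing on driver-less machines.
    # Assumes mxnet.context.num_gpus() is backed by Context::GetGPUCount().
    import mxnet as mx

    n_gpus = mx.context.num_gpus()      # 0 rather than an error when no usable driver exists
    ctx = mx.gpu(0) if n_gpus > 0 else mx.cpu()
    x = mx.nd.zeros((2, 3), ctx=ctx)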
7 changes: 0 additions & 7 deletions python/mxnet/_ctypes/ndarray.py
@@ -130,13 +130,6 @@ def __init__(self, sym, flags=()):
     def __del__(self):
         check_call(_LIB.MXFreeCachedOp(self.handle))
 
-    def _is_from_np_compat_op(self, idx):
-        """Check if the CachedOp's idx-th output is directly from a numpy op."""
-        is_from_np_op = ctypes.c_int(0)
-        check_call(_LIB.MXIsCachedOpOutputFromNumpyCompatOp(self.handle, ctypes.c_int(idx),
-                                                            ctypes.byref(is_from_np_op)))
-        return is_from_np_op.value != 0
-
     def __call__(self, *args, **kwargs):
         """ctypes implementation of imperative invoke wrapper"""
         out = kwargs.pop('out', None)
3 changes: 2 additions & 1 deletion python/mxnet/_ctypes/symbol.py
@@ -26,6 +26,7 @@
 from ..base import SymbolHandle
 from ..base import check_call
 
+# The symbol class to be used (Cython or Ctypes)
 _symbol_cls = None
 _np_symbol_cls = None
 
@@ -117,7 +118,7 @@ def _set_symbol_class(cls):
 
 
 def _set_np_symbol_class(cls):
-    """Set the symbolic class to be cls"""
+    """Set the numpy-compatible symbolic class to be cls"""
     global _np_symbol_cls
     _np_symbol_cls = cls
 
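
The two module-level slots above are filled in at import time so that generic code can build whichever Symbol class (classic or numpy-compatible, Cython or ctypes) was registered. A self-contained sketch of that registration pattern, using a stand-in class rather than the real mxnet.symbol.Symbol:

    # Sketch of the _set_symbol_class / _set_np_symbol_class registration pattern.
    _symbol_cls = None

    def _set_symbol_class(cls):
        """Set the symbolic class to be cls."""
        global _symbol_cls
        _symbol_cls = cls

    class FakeSymbol:                      # stand-in for the real Symbol class
        def __init__(self, handle=None):
            self.handle = handle

    _set_symbol_class(FakeSymbol)
    sym = _symbol_cls(handle=None)         # generic code instantiates whatever was registered
    print(type(sym).__name__)              # FakeSymbol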
8 changes: 1 addition & 7 deletions python/mxnet/base.py
@@ -746,13 +746,6 @@ def write_all_str(module_file, module_all_list):
     check_call(_LIB.MXLoadTVMOp(c_str(_LIB_TVM_OP[0])))
 
 
-def _sanity_check_params(func_name, unsupported_params, param_dict):
-    for param_name in unsupported_params:
-        if param_name in param_dict:
-            raise NotImplementedError("function {} does not support parameter {}"
-                                      .format(func_name, param_name))
-
-
 _NP_OP_PREFIX = '_np_'
 _NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
 
@@ -768,6 +761,7 @@ def _is_np_op(op_name):
 
 
 def _get_op_submodule_name(op_name, op_name_prefix, submodule_name_list):
+    """Get the submodule name of a specific op"""
     assert op_name.startswith(op_name_prefix)
     for submodule_name in submodule_name_list:
         if op_name[len(op_name_prefix):].startswith(submodule_name):
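
The helpers above implement a simple naming convention: numpy-compatible operators are registered with the prefix '_np_', optionally followed by a submodule tag such as '_random_' or '_linalg_'. A standalone sketch of the lookup (the real function's return path is collapsed in the hunk above, and the op names used here are purely illustrative):

    # Sketch of the op-name dispatch convention used by _get_op_submodule_name.
    NP_OP_PREFIX = '_np_'
    NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']

    def get_op_submodule_name(op_name, prefix=NP_OP_PREFIX, submodules=NP_OP_SUBMODULE_LIST):
        assert op_name.startswith(prefix)
        for submodule_name in submodules:
            if op_name[len(prefix):].startswith(submodule_name):
                return submodule_name
        return ''

    print(get_op_submodule_name('_np_random_uniform'))   # '_random_'
    print(get_op_submodule_name('_np_sum'))               # ''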
2 changes: 1 addition & 1 deletion python/mxnet/gluon/block.py
@@ -35,7 +35,7 @@
 from .utils import _indent, _brief_print_list, HookHandle
 from .utils import _check_same_symbol_type, _check_all_np_ndarrays
 from .. import numpy_extension as _mx_npx
-from .. import numpy as _mx_np, numpy_extension as _mx_npx
+from .. import numpy as _mx_np
 from .. util import is_np_array, np_shape, np_array
 
 
1 change: 0 additions & 1 deletion python/mxnet/gluon/data/dataloader.py
@@ -18,7 +18,6 @@
 # coding: utf-8
 # pylint: disable=ungrouped-imports
 """Dataset generator."""
-from __future__ import absolute_import
 __all__ = ['DataLoader']
 
 import pickle
1 change: 0 additions & 1 deletion python/mxnet/gluon/nn/conv_layers.py
@@ -110,7 +110,6 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
         self._kwargs['adj'] = adj
 
         dshape = [0]*(len(kernel_size) + 2)
-
         dshape[layout.find('N')] = 1
         dshape[layout.find('C')] = in_channels
         wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
2 changes: 1 addition & 1 deletion python/mxnet/gluon/utils.py
@@ -440,7 +440,7 @@ def _check_same_symbol_type(symbols):
     the symbols."""
     from ..symbol.numpy import _Symbol as np_symbol
     from ..symbol import Symbol as nd_symbol
-    is_np_sym = bool(isinstance(symbols[0], np_symbol))
+    is_np_sym = isinstance(symbols[0], np_symbol)
     for s in symbols[1:]:
         if is_np_sym != isinstance(s, np_symbol):
             raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol '
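
Dropping the bool() wrapper is purely cosmetic: isinstance() already returns a bool, so the later comparisons behave identically. A small sketch of the check this helper performs, with plain classes standing in for mx.sym.Symbol and the numpy _Symbol:

    # Sketch of the mixed-symbol check; the stand-in classes are illustrative.
    class ClassicSymbol: pass
    class NumpySymbol(ClassicSymbol): pass

    def check_same_symbol_type(symbols):
        is_np_sym = isinstance(symbols[0], NumpySymbol)   # already a bool
        for s in symbols[1:]:
            if is_np_sym != isinstance(s, NumpySymbol):
                raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol')
        return NumpySymbol if is_np_sym else ClassicSymbol

    check_same_symbol_type([NumpySymbol(), NumpySymbol()])      # passes
    # check_same_symbol_type([NumpySymbol(), ClassicSymbol()])  # would raise TypeError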
1 change: 0 additions & 1 deletion python/mxnet/initializer.py
@@ -30,7 +30,6 @@
 from . import registry
 from . import ndarray
 
-
 # inherit str for backward compatibility
 class InitDesc(str):
     """
2 changes: 0 additions & 2 deletions python/mxnet/ndarray/ndarray.py
@@ -229,7 +229,6 @@ def __abs__(self):
 
     def __add__(self, other):
         """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """
-        # other may be the type of mxnet.numpy.ndarray
         return add(self, other)
 
     def __iadd__(self, other):
@@ -248,7 +247,6 @@ def __radd__(self, other):
 
     def __sub__(self, other):
         """x.__sub__(y) <=> x-y <=> mx.nd.subtract(x, y) """
-        # other may be the type of mxnet.numpy.ndarray
         return subtract(self, other)
 
     def __isub__(self, other):
24 changes: 15 additions & 9 deletions python/mxnet/ndarray/numpy/_op.py
@@ -21,15 +21,15 @@
 from __future__ import absolute_import
 import numpy as _np
 from ...base import numeric_types
-from ...util import _sanity_check_params, set_module
+from ...util import set_module
 from ...context import current_context
 from . import _internal as _npi
 
 __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
 
 
 @set_module('mxnet.ndarray.numpy')
-def zeros(shape, dtype=_np.float32, **kwargs):
+def zeros(shape, dtype=_np.float32, order='C', ctx=None):
     """Return a new array of given shape and type, filled with zeros.
     This function currently only supports storing multi-dimensional data
     in row-major (C-style).
@@ -43,6 +43,9 @@ def zeros(shape, dtype=_np.float32, **kwargs):
         behavior is different from NumPy's `ones` function where `float64`
         is the default value, because `float32` is considered as the default
         data type in deep learning.
+    order : {'C'}, optional, default: 'C'
+        How to store multi-dimensional data in memory, currently only row-major
+        (C-style) is supported.
     ctx : Context, optional
         An optional device context (default is the current default context).
 
     Returns
     -------
     out : ndarray
         Array of zeros with the given shape, dtype, and ctx.
     """
-    _sanity_check_params('zeros', ['order'], kwargs)
-    ctx = kwargs.pop('ctx', current_context())
+    if order != 'C':
+        raise NotImplementedError
     if ctx is None:
         ctx = current_context()
     dtype = _np.float32 if dtype is None else dtype
-    return _npi.zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+    return _npi.zeros(shape=shape, ctx=ctx, dtype=dtype)
 
 
 @set_module('mxnet.ndarray.numpy')
-def ones(shape, dtype=None, **kwargs):
+def ones(shape, dtype=_np.float32, order='C', ctx=None):
     """Return a new array of given shape and type, filled with ones.
     This function currently only supports storing multi-dimensional data
     in row-major (C-style).
@@ -74,6 +77,9 @@ def ones(shape, dtype=None, **kwargs):
         behavior is different from NumPy's `ones` function where `float64`
         is the default value, because `float32` is considered as the default
         data type in deep learning.
+    order : {'C'}, optional, default: 'C'
+        How to store multi-dimensional data in memory, currently only row-major
+        (C-style) is supported.
     ctx : Context, optional
         An optional device context (default is the current default context).
 
     Returns
     -------
     out : ndarray
         Array of zeros with the given shape, dtype, and ctx.
     """
-    _sanity_check_params('zeros', ['order'], kwargs)
-    ctx = kwargs.pop('ctx', current_context())
+    if order != 'C':
+        raise NotImplementedError
     if ctx is None:
         ctx = current_context()
     dtype = _np.float32 if dtype is None else dtype
-    return _npi.ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)
+    return _npi.ones(shape=shape, ctx=ctx, dtype=dtype)
 
 
 #pylint: disable= too-many-arguments, no-member, protected-access
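
With the keyword arguments spelled out, unsupported options fail loudly (anything other than row-major order raises NotImplementedError) and ctx falls back to the current context. A hedged usage sketch through the user-facing mxnet.numpy wrappers, which the next file in this diff shows delegating to these functions:

    # Usage sketch for the new zeros/ones signatures.
    import mxnet as mx
    from mxnet import numpy as np    # numpy-compatible front end from this PR

    a = np.zeros((2, 3))                      # float32 zeros on the current context
    b = np.ones((2, 3), ctx=mx.cpu())         # explicit device context
    # np.zeros((2, 3), order='F')             # raises NotImplementedError: only 'C' order is supported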
58 changes: 29 additions & 29 deletions python/mxnet/numpy/multiarray.py
@@ -71,7 +71,7 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
     return hdl
 
 
-# Have to use 0 as default value for stype since plylint does not allow
+# Have to use 0 as default value for stype since pylint does not allow
 # importing _STORAGE_TYPE_DEFAULT from ndarray.py.
 def _np_ndarray_cls(handle, writable=True, stype=0):
     if stype != 0:
@@ -87,9 +87,9 @@ def _get_index(idx):
     if isinstance(idx, NDArray) and not isinstance(idx, ndarray):
         raise TypeError('Cannot have mx.nd.NDArray as index')
     if isinstance(idx, ndarray):
-        return idx._as_nd_ndarray()
+        return idx.as_nd_ndarray()
     elif sys.version_info[0] > 2 and isinstance(idx, range):
-        return array(_np.arange(idx.start, idx.stop, idx.step, dtype=_np.int32))._as_nd_ndarray()
+        return array(_np.arange(idx.start, idx.stop, idx.step, dtype=_np.int32)).as_nd_ndarray()
     else:
         return idx
 
@@ -135,15 +135,15 @@ def __getitem__(self, key):
             return self
 
         if isinstance(key, ndarray):
-            key = key._as_nd_ndarray()
+            key = key.as_nd_ndarray()
         elif isinstance(key, tuple):
             key = [_get_index(idx) for idx in key]
             key = tuple(key)
         elif isinstance(key, list):
             key = [_get_index(idx) for idx in key]
         elif sys.version_info[0] > 2 and isinstance(key, range):
             key = _get_index(key)
-        return self._as_nd_ndarray().__getitem__(key).as_np_ndarray()
+        return self.as_nd_ndarray().__getitem__(key).as_np_ndarray()
         # pylint: enable=too-many-return-statements
 
     def __setitem__(self, key, value):
@@ -154,22 +154,22 @@ def __setitem__(self, key, value):
             if not isinstance(key, tuple) or len(key) != 0:
                 raise IndexError('scalar tensor can only accept `()` as index')
         if isinstance(value, ndarray):
-            value = value._as_nd_ndarray()
+            value = value.as_nd_ndarray()
         # TODO(junwu): Better handling of this situation
         if isinstance(key, tuple) and len(key) == 0:
-            self._as_nd_ndarray().__setitem__(slice(None), value)
+            self.as_nd_ndarray().__setitem__(slice(None), value)
             return
 
         if isinstance(key, ndarray):
-            key = key._as_nd_ndarray()
+            key = key.as_nd_ndarray()
         elif isinstance(key, tuple):
             key = [_get_index(idx) for idx in key]
             key = tuple(key)
         elif isinstance(key, list):
             key = [_get_index(idx) for idx in key]
         elif sys.version_info[0] > 2 and isinstance(key, range):
             key = _get_index(key)
-        self._as_nd_ndarray().__setitem__(key, value)
+        self.as_nd_ndarray().__setitem__(key, value)
 
     def __add__(self, other):
         """x.__add__(y) <=> x + y"""
@@ -399,21 +399,12 @@ def all(self, axis=None, out=None, keepdims=False):
     def any(self, axis=None, out=None, keepdims=False):
         raise NotImplementedError
 
-    def _as_nd_ndarray(self):
-        """This is not a user-facing API."""
+    def as_nd_ndarray(self):
+        """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its fluent methods."""
         hdl = NDArrayHandle()
         check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl)))
         return NDArray(handle=hdl, writable=self.writable)
 
-    def as_nd_ndarray(self):
-        """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its fluent methods."""
-        # TODO(junwu): Uncomment the following lines
-        # if self.ndim == 0:  # TODO(junwu): this costs ~10ns, can be moved to backend
-        #     raise ValueError('cannot convert a scalar np.ndarray to mx.nd.NDArray')
-        # if self.size == 0:  # TODO(junwu): this costs ~10ns, can be moved to backend
-        #     raise ValueError('cannot convert a zero-size np.ndarray to mx.nd.NDArray')
-        return self._as_nd_ndarray()
-
     def as_np_ndarray(self):
         """A convenience function for creating a numpy ndarray from the current ndarray
         with zero copy. For this class, it just returns itself since it's already a
@@ -580,8 +571,8 @@ def copyto(self, other):
         [ 1., 1., 1.]], dtype=float32)
         """
         if isinstance(other, ndarray):
-            other = other._as_nd_ndarray()
-        return self._as_nd_ndarray().copyto(other).as_np_ndarray()
+            other = other.as_nd_ndarray()
+        return self.as_nd_ndarray().copyto(other).as_np_ndarray()
 
     def asscalar(self):
         raise AttributeError('mxnet.numpy.ndarray object has no attribute asscalar')
@@ -1282,7 +1273,7 @@ def tostype(self, stype):
 
 
 @set_module('mxnet.numpy')
-def empty(shape, dtype=None, **kwargs):
+def empty(shape, dtype=float, order='C', ctx=None):
     """Return a new array of given shape and type, without initializing entries.
 
     Parameters
@@ -1293,6 +1284,9 @@ def empty(shape, dtype=None, **kwargs):
         `numpy.float32`. Note that this behavior is different from NumPy's `empty`
         function where `float64` is the default value, because `float32` is
        considered as the default data type in deep learning.
+    order : {'C'}, optional, default: 'C'
+        How to store multi-dimensional data in memory, currently only row-major
+        (C-style) is supported.
     ctx : device context, optional
         Device context on which the memory is allocated. Default is
         `mxnet.context.current_context()`.
@@ -1302,8 +1296,8 @@ def empty(shape, dtype=None, **kwargs):
     out : ndarray
         Array of uninitialized (arbitrary) data of the given shape, dtype, and order.
     """
-    _sanity_check_params('emtpy', ['order'], kwargs)
-    ctx = kwargs.get('ctx', current_context())
+    if order != 'C':
+        raise NotImplementedError
     if ctx is None:
         ctx = current_context()
     if dtype is None:
@@ -1354,7 +1348,7 @@ def array(object, dtype=None, ctx=None):
 
 
 @set_module('mxnet.numpy')
-def zeros(shape, dtype=_np.float32, **kwargs):
+def zeros(shape, dtype=_np.float32, order='C', ctx=None):
     """Return a new array of given shape and type, filled with zeros.
     This function currently only supports storing multi-dimensional data
     in row-major (C-style).
@@ -1368,6 +1362,9 @@ def zeros(shape, dtype=_np.float32, **kwargs):
         behavior is different from NumPy's `ones` function where `float64`
         is the default value, because `float32` is considered as the default
         data type in deep learning.
+    order : {'C'}, optional, default: 'C'
+        How to store multi-dimensional data in memory, currently only row-major
+        (C-style) is supported.
     ctx : Context, optional
         An optional device context (default is the current default context).
@@ -1376,11 +1373,11 @@ def zeros(shape, dtype=_np.float32, **kwargs):
     out : ndarray
         Array of zeros with the given shape, dtype, and ctx.
     """
-    return _mx_nd_np.zeros(shape, dtype, **kwargs)
+    return _mx_nd_np.zeros(shape, dtype, order, ctx)
 
 
 @set_module('mxnet.numpy')
-def ones(shape, dtype=None, **kwargs):
+def ones(shape, dtype=_np.float32, order='C', ctx=None):
     """Return a new array of given shape and type, filled with zeros.
     This function currently only supports storing multi-dimensional data
     in row-major (C-style).
@@ -1394,6 +1391,9 @@ def ones(shape, dtype=None, **kwargs):
         behavior is different from NumPy's `ones` function where `float64`
         is the default value, because `float32` is considered as the default
         data type in deep learning.
+    order : {'C'}, optional, default: 'C'
+        How to store multi-dimensional data in memory, currently only row-major
+        (C-style) is supported.
     ctx : Context, optional
         An optional device context (default is the current default context).
@@ -1402,7 +1402,7 @@ def zeros(shape, dtype=_np.float32, **kwargs):
     out : ndarray
         Array of zeros with the given shape, dtype, and ctx.
     """
-    return _mx_nd_np.ones(shape, dtype, **kwargs)
+    return _mx_nd_np.ones(shape, dtype, order, ctx)
 
 
 @set_module('mxnet.numpy')
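
The rename from the private _as_nd_ndarray() to the public as_nd_ndarray() (and the removal of the old wrapper) makes the zero-copy bridge between the two front ends part of the ndarray API. A hedged round-trip sketch, assuming the module import shown in this diff:

    # Round-trip sketch: mxnet.numpy.ndarray <-> mxnet.ndarray.NDArray (zero copy).
    from mxnet import numpy as np

    a = np.ones((2, 3))                   # numpy-compatible ndarray
    nd = a.as_nd_ndarray()                # classic NDArray view, for legacy/fluent ops
    b = nd.as_np_ndarray()                # back to the numpy front end
    print(type(nd).__name__, type(b).__name__)   # NDArray, ndarray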
