diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 1923f650ba3f..59b1582831d9 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -38,7 +38,7 @@ from . import sampler as _sampler from ... import nd, context -from ...util import is_np_array +from ...util import is_np_shape, is_np_array, set_np from ... import numpy as _mx_np # pylint: disable=reimported if sys.platform == 'darwin' or sys.platform == 'win32': @@ -392,14 +392,21 @@ def same_process_iter(): def __len__(self): return len(self._batch_sampler) + +def _thread_worker_initializer(active_shape, active_array): + """Initializer for ThreadPool.""" + set_np(shape=active_shape, array=active_array) + + _worker_dataset = None -def _worker_initializer(dataset): +def _worker_initializer(dataset, active_shape, active_array): """Initialier for processing pool.""" # global dataset is per-process based and only available in worker processes # this is only necessary to handle MXIndexedRecordIO because otherwise dataset # can be passed as argument global _worker_dataset _worker_dataset = dataset + set_np(shape=active_shape, array=active_array) def _worker_fn(samples, batchify_fn, dataset=None): """Function for processing data in worker process.""" @@ -463,6 +470,9 @@ def __next__(self): batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id)) batch = batch[0] if len(batch) == 1 else batch self._rcvd_idx += 1 + if is_np_array(): + new_batch = [member.as_np_ndarray() for member in batch] + batch = new_batch return batch def next(self): @@ -566,10 +576,13 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers) if self._num_workers > 0: if self._thread_pool: - self._worker_pool = ThreadPool(self._num_workers) + self._worker_pool = ThreadPool(self._num_workers, + initializer=_thread_worker_initializer, + initargs=(is_np_shape(), is_np_array())) else: self._worker_pool = multiprocessing.Pool( - self._num_workers, initializer=_worker_initializer, initargs=[self._dataset]) + self._num_workers, initializer=_worker_initializer, + initargs=[self._dataset, is_np_shape(), is_np_array()]) if batchify_fn is None: if num_workers > 0: self._batchify_fn = default_mp_batchify_fn diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 2648997cf5e5..54af87e9de43 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -370,6 +370,7 @@ def __init__(self, size, keep_ratio=False, interpolation=1): self._size = size self._interpolation = interpolation + @_adapt_np_array def hybrid_forward(self, F, x): return F.image.resize(x, self._size, self._keep, self._interpolation) diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 4122a08563fa..3e8516b02180 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ b/python/mxnet/gluon/nn/conv_layers.py @@ -30,6 +30,7 @@ from ... import symbol from ...base import numeric_types from .activations import Activation +from ...util import is_np_array def _infer_weight_shape(op_name, data_shape, kwargs): @@ -109,7 +110,11 @@ def __init__(self, channels, kernel_size, strides, padding, dilation, if adj is not None: self._kwargs['adj'] = adj - dshape = [0]*(len(kernel_size) + 2) + if is_np_array(): + dshape = [-1]*(len(kernel_size) + 2) + else: + dshape = [0]*(len(kernel_size) + 2) + dshape[layout.find('N')] = 1 dshape[layout.find('C')] = in_channels wshapes = _infer_weight_shape(op_name, dshape, self._kwargs) @@ -129,6 +134,8 @@ def __init__(self, channels, kernel_size, strides, padding, dilation, self.act = None def hybrid_forward(self, F, x, weight, bias=None): + if is_np_array(): + F = F.npx if bias is None: act = getattr(F, self._op_name)(x, weight, name='fwd', **self._kwargs) else: @@ -693,6 +700,8 @@ def _alias(self): return 'pool' def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.Pooling(x, name='fwd', **self._kwargs) def __repr__(self): diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 63dc1b26aeec..bd69503ccb7d 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -516,7 +516,7 @@ def _with_np_array(*args, **kwargs): assert len(args) > 2, "expect at least three arguments in args" if is_np_array(): input_args, kwargs = _to_classic_arrays(*args[2:], **kwargs) - input_args = list(args[0:2]) + input_args + input_args = list(args[0:2]) + list(input_args) out = func(*input_args, **kwargs) return _to_np_arrays(out) else: diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a4a05af6802a..409cbf4d6755 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -111,6 +111,10 @@ def __getitem__(self, key): out = out[idx] return out.reshape(()).as_np_ndarray() if isinstance(key, integer_types): + if key > self.shape[0] - 1: + raise IndexError( + 'index {} is out of bounds for axis 0 with size {}'.format( + key, self.shape[0])) return self._at(key) if isinstance(key, ndarray): key = key._as_nd_ndarray()