Skip to content

Commit

Permalink
fix for chapter6 conv nn (apache#15224)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Jul 31, 2019
1 parent 3f6fe79 commit 68238a7
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 6 deletions.
21 changes: 17 additions & 4 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

from . import sampler as _sampler
from ... import nd, context
from ...util import is_np_array
from ...util import is_np_shape, is_np_array, set_np
from ... import numpy as _mx_np # pylint: disable=reimported

if sys.platform == 'darwin' or sys.platform == 'win32':
Expand Down Expand Up @@ -392,14 +392,21 @@ def same_process_iter():
def __len__(self):
return len(self._batch_sampler)


def _thread_worker_initializer(active_shape, active_array):
"""Initializer for ThreadPool."""
set_np(shape=active_shape, array=active_array)


_worker_dataset = None
def _worker_initializer(dataset):
def _worker_initializer(dataset, active_shape, active_array):
"""Initialier for processing pool."""
# global dataset is per-process based and only available in worker processes
# this is only necessary to handle MXIndexedRecordIO because otherwise dataset
# can be passed as argument
global _worker_dataset
_worker_dataset = dataset
set_np(shape=active_shape, array=active_array)

def _worker_fn(samples, batchify_fn, dataset=None):
"""Function for processing data in worker process."""
Expand Down Expand Up @@ -463,6 +470,9 @@ def __next__(self):
batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
batch = batch[0] if len(batch) == 1 else batch
self._rcvd_idx += 1
if is_np_array():
new_batch = [member.as_np_ndarray() for member in batch]
batch = new_batch
return batch

def next(self):
Expand Down Expand Up @@ -566,10 +576,13 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
if self._num_workers > 0:
if self._thread_pool:
self._worker_pool = ThreadPool(self._num_workers)
self._worker_pool = ThreadPool(self._num_workers,
initializer=_thread_worker_initializer,
initargs=(is_np_shape(), is_np_array()))
else:
self._worker_pool = multiprocessing.Pool(
self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
self._num_workers, initializer=_worker_initializer,
initargs=[self._dataset, is_np_shape(), is_np_array()])
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = default_mp_batchify_fn
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def __init__(self, size, keep_ratio=False, interpolation=1):
self._size = size
self._interpolation = interpolation

@_adapt_np_array
def hybrid_forward(self, F, x):
return F.image.resize(x, self._size, self._keep, self._interpolation)

Expand Down
11 changes: 10 additions & 1 deletion python/mxnet/gluon/nn/conv_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ... import symbol
from ...base import numeric_types
from .activations import Activation
from ...util import is_np_array


def _infer_weight_shape(op_name, data_shape, kwargs):
Expand Down Expand Up @@ -109,7 +110,11 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
if adj is not None:
self._kwargs['adj'] = adj

dshape = [0]*(len(kernel_size) + 2)
if is_np_array():
dshape = [-1]*(len(kernel_size) + 2)
else:
dshape = [0]*(len(kernel_size) + 2)

dshape[layout.find('N')] = 1
dshape[layout.find('C')] = in_channels
wshapes = _infer_weight_shape(op_name, dshape, self._kwargs)
Expand All @@ -129,6 +134,8 @@ def __init__(self, channels, kernel_size, strides, padding, dilation,
self.act = None

def hybrid_forward(self, F, x, weight, bias=None):
if is_np_array():
F = F.npx
if bias is None:
act = getattr(F, self._op_name)(x, weight, name='fwd', **self._kwargs)
else:
Expand Down Expand Up @@ -693,6 +700,8 @@ def _alias(self):
return 'pool'

def hybrid_forward(self, F, x):
if is_np_array():
F = F.npx
return F.Pooling(x, name='fwd', **self._kwargs)

def __repr__(self):
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def _with_np_array(*args, **kwargs):
assert len(args) > 2, "expect at least three arguments in args"
if is_np_array():
input_args, kwargs = _to_classic_arrays(*args[2:], **kwargs)
input_args = list(args[0:2]) + input_args
input_args = list(args[0:2]) + list(input_args)
out = func(*input_args, **kwargs)
return _to_np_arrays(out)
else:
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def __getitem__(self, key):
out = out[idx]
return out.reshape(()).as_np_ndarray()
if isinstance(key, integer_types):
if key > self.shape[0] - 1:
raise IndexError(
'index {} is out of bounds for axis 0 with size {}'.format(
key, self.shape[0]))
return self._at(key)
if isinstance(key, ndarray):
key = key._as_nd_ndarray()
Expand Down

0 comments on commit 68238a7

Please sign in to comment.