Skip to content

Commit

Permalink
[numpy] Fix d2l chapter8 (apache#15237)
Browse files Browse the repository at this point in the history
* Add np op doc

* Fix several issues

* Add a N-D dot b 2D support

* Simplify array creation api

* Add swapaxes

* Fix rnn gluon

* More fix

* Fix pylint

* Delete

* Fix mp windows
  • Loading branch information
reminisce authored and haojin2 committed Jul 31, 2019
1 parent 68238a7 commit e92d044
Show file tree
Hide file tree
Showing 24 changed files with 549 additions and 91 deletions.
88 changes: 88 additions & 0 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file

"""Doc placeholder for numpy ops with prefix _np."""


def _np_reshape(a, newshape, order='C'):
"""Gives a new shape to an array without changing its data.
Parameters
----------
a : ndarray
Array to be reshaped.
newshape : int or tuple of ints
The new shape should be compatible with the original shape. If
an integer, then the result will be a 1-D array of that length.
One shape dimension can be -1. In this case, the value is
inferred from the length of the array and remaining dimensions.
order : {'C'}, optional
Read the elements of `a` using this index order, and place the
elements into the reshaped array using this index order. 'C'
means to read / write the elements using C-like index order,
with the last axis index changing fastest, back to the first
axis index changing slowest. Other order types such as 'F'/'A'
may be added in the future.
Returns
-------
reshaped_array : ndarray
It will be always a copy of the original array. This behavior is different
from the official NumPy package where views of the original array may be
generated.
See Also
--------
ndarray.reshape : Equivalent method.
"""
pass


def _np_ones_like(a):
"""Return an array of ones with the same shape and type as a given array.
Parameters
----------
a : ndarray
The shape and data-type of `a` define these same attributes of
the returned array.
Returns
-------
out : ndarray
Array of ones with the same shape and type as `a`.
"""
pass


def _np_zeros_like(a):
"""Return an array of zeros with the same shape and type as a given array.
Parameters
----------
a : ndarray
The shape and data-type of `a` define these same attributes of
the returned array.
Returns
-------
out : ndarray
Array of zeros with the same shape and type as `a`.
"""
pass
4 changes: 4 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
make_op_func : function
Function for creating op functions.
"""
from . import _numpy_op_doc as _np_op_doc
if np_module_name == 'numpy':
op_name_prefix = _NP_OP_PREFIX
submodule_name_list = _NP_OP_SUBMODULE_LIST
Expand Down Expand Up @@ -852,3 +853,6 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
function.__module__ = module_name_local
setattr(cur_module, function.__name__, function)
cur_module.__all__.append(function.__name__)

if hasattr(_np_op_doc, name):
function.__doc__ = getattr(_np_op_doc, name).__doc__
3 changes: 0 additions & 3 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,6 @@ def __next__(self):
batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
batch = batch[0] if len(batch) == 1 else batch
self._rcvd_idx += 1
if is_np_array():
new_batch = [member.as_np_ndarray() for member in batch]
batch = new_batch
return batch

def next(self):
Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/gluon/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,9 @@ def __init__(self, input_dim, output_dim, dtype='float32',
init=weight_initializer, dtype=dtype,
allow_deferred_init=True, grad_stype=grad_stype)

@_adapt_np_array
def hybrid_forward(self, F, x, weight):
if is_np_array():
F = F.npx
return F.Embedding(x, weight, name='fwd', **self._kwargs)

def __repr__(self):
Expand Down
33 changes: 20 additions & 13 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from ... import ndarray, symbol
from .. import HybridBlock, tensor_types
from . import rnn_cell
from ...util import is_np_array


class _RNNLayer(HybridBlock):
"""Implementation of recurrent layers."""
Expand Down Expand Up @@ -217,7 +219,10 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
info.update(kwargs)
else:
info = kwargs
states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
state = func(name='%sh0_%d' % (self.prefix, i), **info)
if is_np_array():
state = state.as_np_ndarray()
states.append(state)
return states

def __call__(self, inputs, states=None, sequence_length=None, **kwargs):
Expand All @@ -236,7 +241,6 @@ def __call__(self, inputs, states=None, sequence_length=None, **kwargs):
else:
return super(_RNNLayer, self).__call__(inputs, states, **kwargs)


def hybrid_forward(self, F, inputs, states, sequence_length=None, **kwargs):
if F is ndarray:
batch_size = inputs.shape[self._layout.find('N')]
Expand All @@ -254,8 +258,9 @@ def hybrid_forward(self, F, inputs, states, sequence_length=None, **kwargs):

def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
""" forward using CUDNN or CPU kenrel"""
swapaxes = F.np.swapaxes if is_np_array() else F.swapaxes
if self._layout == 'NTC':
inputs = F.swapaxes(inputs, dim1=0, dim2=1)
inputs = swapaxes(inputs, 0, 1)
if self._projection_size is None:
params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
for t in ['weight', 'bias']
Expand All @@ -270,29 +275,31 @@ def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
for g in ['i2h', 'h2h', 'h2r']
if g != 'h2r' or t != 'bias')

params = F._internal._rnn_param_concat(*params, dim=0)
rnn_param_concat = F.np._internal.rnn_param_concat if is_np_array()\
else F._internal._rnn_param_concat
params = rnn_param_concat(*params, dim=0)

if self._use_sequence_length:
rnn_args = states + [sequence_length]
else:
rnn_args = states

rnn = F.RNN(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
state_size=self._hidden_size, projection_size=self._projection_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode,
lstm_state_clip_min=self._lstm_state_clip_min,
lstm_state_clip_max=self._lstm_state_clip_max,
lstm_state_clip_nan=self._lstm_state_clip_nan)

rnn_fn = F.npx.RNN if is_np_array() else F.RNN
rnn = rnn_fn(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
state_size=self._hidden_size, projection_size=self._projection_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode,
lstm_state_clip_min=self._lstm_state_clip_min,
lstm_state_clip_max=self._lstm_state_clip_max,
lstm_state_clip_nan=self._lstm_state_clip_nan)

if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
else:
outputs, states = rnn[0], [rnn[1]]

if self._layout == 'NTC':
outputs = F.swapaxes(outputs, dim1=0, dim2=1)
outputs = swapaxes(outputs, 0, 1)

return outputs, states

Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ def _slice(self, start, stop):

check_call(_LIB.MXNDArraySlice(
self.handle, mx_uint(start), mx_uint(stop), ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)
return self.__class__(handle=handle, writable=self.writable)

def _at(self, idx):
"""Returns a view of the array sliced at `idx` in the first dim.
Expand Down Expand Up @@ -1085,7 +1085,7 @@ def reshape(self, *shape, **kwargs):
c_array(ctypes.c_int64, shape),
reverse,
ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)
return self.__class__(handle=handle, writable=self.writable)

def reshape_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reshape_like`.
Expand Down
45 changes: 44 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip']
'clip', 'swapaxes', 'expand_dims']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -495,3 +495,46 @@ def clip(a, a_min, a_max, out=None):
if a_max is None:
a_max = float('inf')
return _npi.clip(a, a_min, a_max, out=out)


@set_module('mxnet.ndarray.numpy')
def swapaxes(a, axis1, axis2):
"""Interchange two axes of an array.
Parameters
----------
a : ndarray
Input array.
axis1 : int
First axis.
axis2 : int
Second axis.
Returns
-------
a_swapped : ndarray
Swapped array. This is always a copy of the input array.
"""
return _npi.swapaxes(a, dim1=axis1, dim2=axis2)


@set_module('mxnet.ndarray.numpy')
def expand_dims(a, axis):
"""Expand the shape of an array.
Insert a new axis that will appear at the `axis` position in the expanded
Parameters
----------
a : ndarray
Input array.
axis : int
Position in the expanded axes where the new axis is placed.
Returns
-------
res : ndarray
Output array. The number of dimensions is one greater than that of
the input array.
"""
return _npi.expand_dims(a, axis)
2 changes: 0 additions & 2 deletions python/mxnet/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
Expand Down
Loading

0 comments on commit e92d044

Please sign in to comment.