Skip to content

Commit

Permalink
[numpy] Fix numpy import in python2 (apache#14537)
Browse files Browse the repository at this point in the history
* Fix several test failures

* Fix subgraph op infer shape

* Fix sparse slice

* Fix deconv infer shape

* Fix numpy import compatibility problem in python2
  • Loading branch information
reminisce committed Apr 13, 2019
1 parent 6dd3995 commit b82bf61
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 32 deletions.
2 changes: 0 additions & 2 deletions python/mxnet/ndarray/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import os as _os
import sys as _sys

import numpy as np

try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .._ctypes.ndarray import NDArrayBase, CachedOp
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# coding: utf-8
# pylint: disable=wildcard-import, unused-wildcard-import,redefined-outer-name
"""Contrib NDArray API of MXNet."""
from __future__ import absolute_import
import math
import numpy as np
from ..context import current_context
Expand Down
7 changes: 4 additions & 3 deletions python/mxnet/ndarray/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.

"""Register backend ops in mxnet.ndarray namespace"""
from __future__ import absolute_import
import os as _os
import ctypes
import numpy as np # pylint: disable=unused-import
import numpy as _np # pylint: disable=unused-import

from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import
from ..ndarray_doc import _build_doc
Expand Down Expand Up @@ -103,7 +104,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
if dtype_name is not None:
code.append("""
if '%s' in kwargs:
kwargs['%s'] = np.dtype(kwargs['%s']).name"""%(
kwargs['%s'] = _np.dtype(kwargs['%s']).name"""%(
dtype_name, dtype_name, dtype_name))
code.append("""
_ = kwargs.pop('name', None)
Expand Down Expand Up @@ -136,7 +137,7 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
code.append("""
if %s is not _Null:
keys.append('%s')
vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))

if not signature_only:
code.append("""
Expand Down
2 changes: 0 additions & 2 deletions python/mxnet/symbol/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import sys as _sys
import os as _os

import numpy as np

try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .._ctypes.symbol import SymbolBase, _set_symbol_class
Expand Down
7 changes: 4 additions & 3 deletions python/mxnet/symbol/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

# pylint: disable=unused-import
"""Register backend ops in mxnet.symbol namespace."""
from __future__ import absolute_import
import os as _os
import ctypes
import numpy as np
import numpy as _np

from . import _internal
from ._internal import SymbolBase, _symbol_creator
Expand Down Expand Up @@ -109,7 +110,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
if dtype_name is not None:
code.append("""
if '%s' in kwargs:
kwargs['%s'] = np.dtype(kwargs['%s']).name"""%(
kwargs['%s'] = _np.dtype(kwargs['%s']).name"""%(
dtype_name, dtype_name, dtype_name))
code.append("""
attr = kwargs.pop('attr', None)
Expand Down Expand Up @@ -175,7 +176,7 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
code.append("""
if %s is not _Null:
_keys.append('%s')
_vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
_vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))

code.append("""
if not hasattr(NameManager._current, "value"):
Expand Down
5 changes: 5 additions & 0 deletions src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,11 @@ inline void ParallelCopy(DType* dst, const DType* src, index_t size) {
* 4. -1 dim size means the dimension's size is unknown.
* so that operator's infer shape function can work in backend.
* \param shape to be converted.
* Note: It is possible that the shape to be converted is already
* numpy compatible. For example, when a subgraph operator's infer
* shape function is called from the infer shape pass of the whole
* graph, its input/output shapes have been converted to numpy
* compatible shapes.
*/
inline void ConvertToNumpyShape(mxnet::TShape* shape) {
if (shape->ndim() == 0) { // legacy shape ndim = 0 means unknown
Expand Down
4 changes: 2 additions & 2 deletions src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,10 @@ class LeakyReLUProp : public OperatorProperty {
CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
}
const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData);
if (dshape.ndim() == 0) return false;
if (!mxnet::ndim_is_known(dshape)) return false;
if (param_.act_type == leakyrelu::kPReLU) {
const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma);
if (gshape.ndim() == 0) {
if (!mxnet::ndim_is_known(gshape)) {
in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1]));
}
if (dshape == gshape) {
Expand Down
12 changes: 7 additions & 5 deletions src/operator/nn/deconvolution-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,13 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
for (size_t i = 0; i < ndim; i++) {
// input.ndim() can be larger than ndim, in case that the complete input
// shape was passed and not only the ndim last ones
o_pad[i] = stride[i] * (input[(input_ndim - ndim) + i] - 1) + DilatedKernelSize(i);
CHECK_GE(o_pad[i], target_shape[i]) << "too big target shape";
o_pad[i] -= target_shape[i];
o_adj[i] = o_pad[i] % 2;
o_pad[i] = (o_pad[i] + 1) / 2;
if (mxnet::dim_size_is_known(input, input_ndim - ndim + i)) {
o_pad[i] = stride[i] * (input[(input_ndim - ndim) + i] - 1) + DilatedKernelSize(i);
CHECK_GE(o_pad[i], target_shape[i]) << "too big target shape";
o_pad[i] -= target_shape[i];
o_adj[i] = o_pad[i] % 2;
o_pad[i] = (o_pad[i] + 1) / 2;
}
}
} else {
for (size_t i = 0; i < ndim; i++) {
Expand Down
50 changes: 37 additions & 13 deletions src/operator/nn/deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
}
out_shape->resize(1, mxnet::TShape());
const mxnet::TShape &dshape = (*in_shape)[deconv::kData];
if (!shape_is_known(dshape)) return false;
if (!mxnet::ndim_is_known(dshape)) return false;

if (param_.kernel.ndim() == 1) {
// 1d conv
Expand Down Expand Up @@ -90,8 +90,12 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
Shape<3> oshape;
oshape[0] = dshape_ncw[0];
oshape[1] = param_.num_filter;
oshape[2] = param_.stride[0] * (dshape_ncw[2] - 1) +
dilated_ksize_x - 2 * o_pad[0] + o_adj[0];
if (mxnet::dim_size_is_known(dshape_ncw[2])) {
oshape[2] = param_.stride[0] * (dshape_ncw[2] - 1) +
dilated_ksize_x - 2 * o_pad[0] + o_adj[0];
} else {
oshape[2] = -1;
}

if (param_.target_shape.ndim() > 0) {
if (param_.target_shape[0] > 0) {
Expand Down Expand Up @@ -141,10 +145,18 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
Shape<4> oshape;
oshape[0] = dshape_nchw[0];
oshape[1] = param_.num_filter;
oshape[2] = param_.stride[0] * (dshape_nchw[2] - 1) +
dilated_ksize_y - 2 * o_pad[0] + o_adj[0];
oshape[3] = param_.stride[1] * (dshape_nchw[3] - 1) +
dilated_ksize_x - 2 * o_pad[1] + o_adj[1];
if (mxnet::dim_size_is_known(dshape_nchw[2])) {
oshape[2] = param_.stride[0] * (dshape_nchw[2] - 1) +
dilated_ksize_y - 2 * o_pad[0] + o_adj[0];
} else {
oshape[2] = -1;
}
if (mxnet::dim_size_is_known(dshape_nchw[3])) {
oshape[3] = param_.stride[1] * (dshape_nchw[3] - 1) +
dilated_ksize_x - 2 * o_pad[1] + o_adj[1];
} else {
oshape[3] = -1;
}

if (param_.target_shape.ndim() > 1) {
if (param_.target_shape[0] > 0) {
Expand Down Expand Up @@ -203,12 +215,24 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
Shape<5> oshape;
oshape[0] = dshape_ncdhw[0];
oshape[1] = param_.num_filter;
oshape[2] = param_.stride[0] * (dshape_ncdhw[2] - 1) +
dilated_ksize_d - 2 * o_pad[0] + o_adj[0];
oshape[3] = param_.stride[1] * (dshape_ncdhw[3] - 1) +
dilated_ksize_y - 2 * o_pad[1] + o_adj[1];
oshape[4] = param_.stride[2] * (dshape_ncdhw[4] - 1) +
dilated_ksize_x - 2 * o_pad[2] + o_adj[2];
if (mxnet::dim_size_is_known(dshape_ncdhw[2])) {
oshape[2] = param_.stride[0] * (dshape_ncdhw[2] - 1) +
dilated_ksize_d - 2 * o_pad[0] + o_adj[0];
} else {
oshape[2] = -1;
}
if (mxnet::dim_size_is_known(dshape_ncdhw[3])) {
oshape[3] = param_.stride[1] * (dshape_ncdhw[3] - 1) +
dilated_ksize_y - 2 * o_pad[1] + o_adj[1];
} else {
oshape[3] = -1;
}
if (mxnet::dim_size_is_known(dshape_ncdhw[4])) {
oshape[4] = param_.stride[2] * (dshape_ncdhw[4] - 1) +
dilated_ksize_x - 2 * o_pad[2] + o_adj[2];
} else {
oshape[4] = -1;
}

if (param_.target_shape.ndim() > 2) {
if (param_.target_shape[0] > 0) {
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
mxnet::TShape begin(N, -1), end(N, -1);
for (int i = 0; i < N; ++i) {
int s = 0;
if (param.begin[i]) {
if (i < param.begin.ndim() && param.begin[i]) {
s = *param.begin[i];
if (s < 0) s += ishape[i];
}
Expand Down
6 changes: 5 additions & 1 deletion tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ def test_ndarray_setitem():

# numpy assignment for empty axis
for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]:
x = mx.nd.zeros(trivial_shape)
if trivial_shape == tuple():
with mx.numpy.enable_np_comp():
x = mx.nd.zeros(trivial_shape)
else:
x = mx.nd.zeros(trivial_shape)
x[:] = np.ones(trivial_shape)
x_np = np.ones(trivial_shape, dtype=x.dtype)
assert x.shape == trivial_shape
Expand Down

0 comments on commit b82bf61

Please sign in to comment.