Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix numpy import compatibility problem in python2
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Mar 27, 2019
1 parent 101e714 commit 6e44528
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 10 deletions.
2 changes: 0 additions & 2 deletions python/mxnet/ndarray/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import os as _os
import sys as _sys

import numpy as np

try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .._ctypes.ndarray import NDArrayBase, CachedOp
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# coding: utf-8
# pylint: disable=wildcard-import, unused-wildcard-import,redefined-outer-name
"""Contrib NDArray API of MXNet."""
from __future__ import absolute_import
import math
import numpy as np
from ..context import current_context
Expand Down
7 changes: 4 additions & 3 deletions python/mxnet/ndarray/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
# under the License.

"""Register backend ops in mxnet.ndarray namespace"""
from __future__ import absolute_import
import os as _os
import ctypes
import numpy as np # pylint: disable=unused-import
import numpy as _np # pylint: disable=unused-import

from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import
from ..ndarray_doc import _build_doc
Expand Down Expand Up @@ -103,7 +104,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
if dtype_name is not None:
code.append("""
if '%s' in kwargs:
kwargs['%s'] = np.dtype(kwargs['%s']).name"""%(
kwargs['%s'] = _np.dtype(kwargs['%s']).name"""%(
dtype_name, dtype_name, dtype_name))
code.append("""
_ = kwargs.pop('name', None)
Expand Down Expand Up @@ -136,7 +137,7 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
code.append("""
if %s is not _Null:
keys.append('%s')
vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))

if not signature_only:
code.append("""
Expand Down
2 changes: 0 additions & 2 deletions python/mxnet/symbol/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import sys as _sys
import os as _os

import numpy as np

try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .._ctypes.symbol import SymbolBase, _set_symbol_class
Expand Down
7 changes: 4 additions & 3 deletions python/mxnet/symbol/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

# pylint: disable=unused-import
"""Register backend ops in mxnet.symbol namespace."""
from __future__ import absolute_import
import os as _os
import ctypes
import numpy as np
import numpy as _np

from . import _internal
from ._internal import SymbolBase, _symbol_creator
Expand Down Expand Up @@ -109,7 +110,7 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
if dtype_name is not None:
code.append("""
if '%s' in kwargs:
kwargs['%s'] = np.dtype(kwargs['%s']).name"""%(
kwargs['%s'] = _np.dtype(kwargs['%s']).name"""%(
dtype_name, dtype_name, dtype_name))
code.append("""
attr = kwargs.pop('attr', None)
Expand Down Expand Up @@ -175,7 +176,7 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
code.append("""
if %s is not _Null:
_keys.append('%s')
_vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))
_vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))

code.append("""
if not hasattr(NameManager._current, "value"):
Expand Down

0 comments on commit 6e44528

Please sign in to comment.