This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[numpy] Refactor np modules #14989

Merged · 11 commits · May 18, 2019
73 changes: 25 additions & 48 deletions example/numpy/demo.ipynb
@@ -6,21 +6,21 @@
"source": [
"# Fundamentals of MXNet Numpy Module\n",
"\n",
"## Operator Namespaces for Imperative Programming\n",
"## Namespaces for Imperative Programming\n",
"- `mxnet.numpy`: Regular NumPy operators\n",
"- `mxnet.numpy.random`: NumPy random operators\n",
"- `mxnet.numpy.linalg`: NumPy linear algebra operators\n",
"- `mxnet.numpy.ext`: Operators implemented in MXNet that do not exist in official NumPy\n",
"- `mxnet.numpy_extension`: Operators implemented in MXNet that do not exist in the official NumPy\n",
"\n",
"## Operator Namespaces for Gluon\n",
"`F` can be either `mxnet.ndarray` or `mxnet.symbol`.\n",
"`F` can be either `mxnet.ndarray` or `mxnet.symbol`. Note that `np` and `npe` are aliases of `numpy` and `numpy_extension`, respectively.\n",
"- `F.np`: Regular NumPy operators\n",
"- `F.np.random`: NumPy random operators\n",
"- `F.np.linalg`: NumPy linear algebra operators\n",
"- `F.np.ext`: Operators implemented in MXNet that do not exist in official NumPy\n",
"- `F.npe`: Operators implemented in MXNet that do not exist in official NumPy\n",
"\n",
"## New `ndarray` and `symbol`\n",
"`mxnet.numpy.ndarray` and `mxnet.symbol.numpy._NumpySymbol` (not visible to users)\n",
"`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol` (not visible to users)\n",
"- Same name as in the official NumPy package\n",
"- Dispatch convience fluent method calls to MXNet Numpy operators\n",
"- Override many convenience fluent methods that do not exist in the official NumPy ndarray\n",
@@ -46,7 +46,7 @@
"\n",
"# create a scalar tensor\n",
"x = np.array(3.14)\n",
"print(x)"
"print(x) # x is actually an ndarray, but a scalar value will be printed"
]
},
{
@@ -170,13 +170,15 @@
"from mxnet import gluon\n",
"class TestBinaryBroadcast(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x1, x2):\n",
" print(\"x1 type:\", str(type(x1)))\n",
" print(\"x2 type:\", str(type(x2)))\n",
" print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
" print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
" return x1 + x2\n",
"\n",
"net = TestBinaryBroadcast()\n",
"x1 = mx.nd.ones((2, 1))\n",
"x2 = mx.nd.ones((1, 3))\n",
"print('x1 input tensor type: ', str(type(x1)))\n",
"print('x2 input tensor type: ', str(type(x2)))\n",
"out = net(x1, x2) # ok: imperative execution supports broadcasting\n",
"print(out)"
]
@@ -203,13 +205,15 @@
"source": [
"class TestBinaryBroadcast2(gluon.HybridBlock):\n",
" def hybrid_forward(self, F, x1, x2):\n",
" print(\"x1 type:\", str(type(x1)))\n",
" print(\"x2 type:\", str(type(x2)))\n",
" print(\"x1 type in hybrid_forward:\", str(type(x1)))\n",
" print(\"x2 type in hybrid_forward:\", str(type(x2)))\n",
" return x1.as_np_ndarray() + x2 # convert x1 to new numpy ndarray/symbol\n",
"\n",
"net2 = TestBinaryBroadcast2()\n",
"net2.hybridize()\n",
"\n",
"print('x1 input tensor type: ', str(type(x1)))\n",
"print('x2 input tensor type: ', str(type(x2)))\n",
"out =net2(x1, x2)\n",
"print(out)"
]
@@ -224,7 +228,9 @@
"net.hybridize() # mark the block for execution using a computational graph\n",
"\n",
"x1 = x1.as_np_ndarray() # convert x1 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
"print('x1 input tensor type: ', str(type(x1)))\n",
"x2 = x2.as_np_ndarray() # convert x2 to np.ndarray so that _NumpySymbol will be used in graph construction\n",
"print('x2 input tensor type: ', str(type(x2)))\n",
"out = net(x1, x2) # ok: `+` operation supports broadcasting for _NumpySymbol\n",
"print(out) # mxnet.numpy.ndarray type, because it's from a np operator"
]
@@ -245,7 +251,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## MXNet Numpy Operators in Imperative Programming"
"### MXNet Numpy Operators in Imperative Programming"
]
},
{
@@ -255,15 +261,9 @@
"outputs": [],
"source": [
"import mxnet as mx\n",
"from mxnet import numpy as np\n",
"from mxnet import numpy as np, numpy_extension as npe\n",
"from mxnet import autograd\n",
"try:\n",
" from mxboard import SummaryWriter\n",
"except ImportError:\n",
" SummaryWriter = None\n",
"\n",
"# create a summary writer for visualization\n",
"sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n",
"\n",
"# Use numpy-compatible semantics to support scalar tensors\n",
"mx.set_np_compat(True)\n",
@@ -285,11 +285,11 @@
"learning_rate = 1e-6\n",
"\n",
"\n",
"for t in range(1000):\n",
"for t in range(50):\n",
" with autograd.record():\n",
" # Forward pass: compute predicted y\n",
" h = x.dot(w1) # equivalent to np.dot(x, w1)\n",
" h_relu = np.ext.relu(h) # equivalent to mx.nd.relu(h)\n",
" h_relu = npe.relu(h) # equivalent to mx.nd.relu(h)\n",
" y_pred = h_relu.dot(w2) # equivalent to np.dot(h_relu, w2)\n",
"\n",
" # Compute loss\n",
@@ -302,23 +302,14 @@
"\n",
" # Update weights\n",
" w1 -= learning_rate * w1.grad\n",
" w2 -= learning_rate * w2.grad\n",
"\n",
" if sw is not None:\n",
" sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n",
" if t % 50 == 0:\n",
" sw.add_histogram(tag='w1', values=w1, global_step=t)\n",
" sw.add_histogram(tag='w2', values=w2, global_step=t)\n",
"\n",
"if sw is not None:\n",
" sw.close()"
" w2 -= learning_rate * w2.grad"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MXNet Numpy Operators in Gluon `HybridBlock`"
"### MXNet Numpy Operators in Gluon `HybridBlock`"
]
},
{
@@ -329,13 +320,7 @@
"source": [
"import mxnet as mx\n",
"from mxnet import gluon, autograd\n",
"try:\n",
" from mxboard import SummaryWriter\n",
"except ImportError:\n",
" SummaryWriter = None\n",
"\n",
"# create a summary writer for visualization\n",
"sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n",
"\n",
"# Use numpy-compatible semantics to support scalar tensors\n",
"mx.set_np_compat(True)\n",
@@ -352,7 +337,7 @@
"\n",
" def hybrid_forward(self, F, x, w1, w2):\n",
" h = x.dot(w1) # equivalent to F.np.dot(x, w1)\n",
" h_relu = F.np.ext.relu(h) # equivalent to F.relu(h)\n",
" h_relu = F.npe.relu(h) # equivalent to F.relu(h)\n",
" y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2)\n",
" return y_pred\n",
"\n",
@@ -373,21 +358,13 @@
"total_loss = TotalLoss()\n",
"trainer = gluon.Trainer(regressor.collect_params(), 'sgd', {'learning_rate': 1e-3, 'momentum': 0.9})\n",
"\n",
"for t in range(1000):\n",
"for t in range(50):\n",
" with autograd.record():\n",
" output = regressor(x) # output is a type of np.ndarray because np.dot is the last op in the network\n",
" loss = total_loss(output, y) # loss is a scalar np.ndarray\n",
" loss.backward()\n",
" print(t, loss) # note that loss.asnumpy() is called\n",
" trainer.step(1)\n",
" if sw is not None:\n",
" sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n",
" if t % 50 == 0:\n",
" for k, v in regressor.collect_params().items():\n",
" sw.add_histogram(tag=k, values=v.data(), global_step=t)\n",
"\n",
"if sw is not None:\n",
" sw.close()"
" trainer.step(1)"
]
}
],
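Condensing the broadcasting cells of this notebook into one hedged sketch (assuming the `as_np_ndarray()` conversion added by this PR), the hybridized path broadcasts once the inputs are numpy ndarrays:

```python
# Sketch only: summary of the broadcasting demo above.
import mxnet as mx
from mxnet import gluon

class Add(gluon.HybridBlock):
    def hybrid_forward(self, F, x1, x2):
        return x1 + x2                     # classic Symbol '+' does not broadcast in a graph

net = Add()
net.hybridize()                            # execute through a computational graph

x1 = mx.nd.ones((2, 1)).as_np_ndarray()    # numpy symbol path supports broadcasting
x2 = mx.nd.ones((1, 3)).as_np_ndarray()
print(net(x1, x2))                         # (2, 3) result, mxnet.numpy.ndarray
```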
17 changes: 0 additions & 17 deletions include/mxnet/c_api.h
@@ -2788,14 +2788,6 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
EngineVarHandle mutable_vars_handle, int num_mutable_vars,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
/*!
* \brief Determines if an op is a Numpy op by its name prefix.
* Every Numpy op starts with a prefix string "_numpy_".
* \param creator Operator handle
* \param is_np_op Indicator of whether creator is a numpy op handle
*/
MXNET_DLL int MXIsNumpyCompatOp(AtomicSymbolCreator creator,
int* is_np_op);
/*!
* \brief Create an NDArray from source sharing the same data chunk.
* \param src source NDArray
@@ -2808,15 +2800,6 @@ MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out);
* \param out new Symbol sharing the same graph structure with src
*/
MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out);
/*!
* \brief Checks if an output of CachedOp is from a numpy op.
* \param handle CachedOp shared ptr
* \param output_idx index of the output of the CachedOp
* \param is_from_np_op indicator of whether the output is from a numpy op
*/
MXNET_DLL int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle,
int output_idx,
int* is_from_np_op);

#ifdef __cplusplus
}
9 changes: 0 additions & 9 deletions include/mxnet/op_attr_types.h
@@ -319,15 +319,6 @@ using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
size_t index)>;

/*!
* \brief Indicates whether this operator is NumPy compatible.
* It is for distinguishing the operator from classic MXNet operators
* which do not support zero-dim and zero-size tensors.
* In Python, it is used to determine whether to output numpy ndarrays
* or symbols that are NumPy compatible.
*/
using TIsNumpyCompatible = bool;

} // namespace mxnet

#endif // MXNET_OP_ATTR_TYPES_H_
3 changes: 3 additions & 0 deletions python/mxnet/__init__.py
@@ -29,6 +29,9 @@
from . import ndarray
from . import ndarray as nd
from . import numpy
from . import numpy_extension
from . import numpy as np
from . import numpy_extension as npe
from . import name
# use mx.sym as short for symbol
from . import symbol as sym
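The imports above register top-level aliases so the new namespaces are reachable directly from the package. A minimal sketch of what that enables (assuming the `mx.np` / `mx.npe` aliases added in this file):

```python
# Sketch only: using the aliases registered in python/mxnet/__init__.py.
import mxnet as mx

mx.set_np_compat(True)
x = mx.np.array(3.14)                       # same module as `from mxnet import numpy as np`
y = mx.npe.relu(mx.np.array([-1.0, 2.0]))   # same module as mxnet.numpy_extension
print(type(x), type(y))
```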
19 changes: 10 additions & 9 deletions python/mxnet/_ctypes/ndarray.py
@@ -26,7 +26,7 @@
from ..base import _LIB
from ..base import c_str_array, c_handle_array
from ..base import NDArrayHandle, CachedOpHandle
from ..base import check_call, _is_np_compat_op
from ..base import check_call


class NDArrayBase(object):
@@ -70,7 +70,7 @@ def _set_np_ndarray_class(cls):
_np_ndarray_cls = cls


def _imperative_invoke(handle, ndargs, keys, vals, out):
def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
"""ctypes implementation of imperative invoke wrapper"""
if out is not None:
original_output = out
@@ -99,9 +99,9 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
c_str_array([str(s) for s in vals]),
ctypes.byref(out_stypes)))

create_ndarray_fn = _np_ndarray_cls if is_np_op else _ndarray_cls
if original_output is not None:
return original_output
create_ndarray_fn = _np_ndarray_cls if _is_np_compat_op(handle) else _ndarray_cls
if num_output.value == 1:
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
@@ -112,11 +112,14 @@

class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]
__slots__ = ["handle", "is_np_sym"]

def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()

from ..symbol.numpy._symbol import _Symbol
self.is_np_sym = True if isinstance(sym, _Symbol) else False

check_call(_LIB.MXCreateCachedOpEx(
sym.handle,
len(flags),
@@ -167,12 +170,10 @@ def __call__(self, *args, **kwargs):

if original_output is not None:
return original_output
create_ndarray_fn = _np_ndarray_cls if self.is_np_sym else _ndarray_cls
if num_output.value == 1:
create_ndarray_fn = _np_ndarray_cls if self._is_from_np_compat_op(0) else _ndarray_cls
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
else:
return [_np_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
if self._is_from_np_compat_op(i) else
_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
for i in range(num_output.value)]
return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
stype=out_stypes[i]) for i in range(num_output.value)]
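After this change, both `_imperative_invoke` and `CachedOp.__call__` pick the output wrapper class from a flag computed on the Python side (`is_np_op`, or `is_np_sym` recorded once at `CachedOp` construction) instead of querying the backend per output. A simplified sketch of that dispatch pattern, with placeholder classes standing in for `_ndarray_cls` / `_np_ndarray_cls`:

```python
# Sketch only: the front-end dispatch pattern used above (placeholder classes).
class ClassicNDArray:
    pass

class NumpyNDArray:
    pass

def wrap_outputs(handles, is_np_op):
    """Wrap raw output handles with one class chosen once per call."""
    create = NumpyNDArray if is_np_op else ClassicNDArray
    outs = [create() for _ in handles]      # real code wraps NDArrayHandle pointers
    return outs[0] if len(outs) == 1 else outs
```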
10 changes: 4 additions & 6 deletions python/mxnet/_ctypes/symbol.py
@@ -22,7 +22,7 @@

import ctypes
from ..base import _LIB
from ..base import c_str_array, c_handle_array, c_str, mx_uint, _is_np_compat_op
from ..base import c_str_array, c_handle_array, c_str, mx_uint
from ..base import SymbolHandle
from ..base import check_call

@@ -122,7 +122,7 @@ def _set_np_symbol_class(cls):
_np_symbol_cls = cls


def _symbol_creator(handle, args, kwargs, keys, vals, name):
def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op):
sym_handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateAtomicSymbol(
ctypes.c_void_p(handle),
@@ -135,10 +135,8 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name):
raise TypeError(
'Operators with variable length input can only accept input'
'Symbols either as positional or keyword arguments, not both')
if _is_np_compat_op(handle):
s = _np_symbol_cls(sym_handle)
else:
s = _symbol_cls(sym_handle)
create_symbol_fn = _np_symbol_cls if is_np_op else _symbol_cls
s = create_symbol_fn(sym_handle)
if args:
s._compose(*args, name=name)
elif kwargs: