From 27de2d6506c5b980f54f3eaac31bbc27333c9ab3 Mon Sep 17 00:00:00 2001 From: reminisce Date: Wed, 26 Jun 2019 20:35:06 -0700 Subject: [PATCH] [numpy] Change d2l chapters cv and gan to use numpy (#15368) * Change op name style to lower case underscore * Add ops under image to npx * Add image submodule to npx * Fix split_and_load use np * Fix fine tuning * Fix bbox and anchor * Fix odd * Fix ssd and rcnn * Remove restriction on binary element-wise scalar * Fix gan * Fix sanity * Try to fix website build failure * Add npx.random.seed * Fix doc --- python/mxnet/_numpy_op_doc.py | 5 +- python/mxnet/base.py | 3 +- python/mxnet/gluon/block.py | 23 +++++- python/mxnet/gluon/data/vision/datasets.py | 5 +- python/mxnet/gluon/data/vision/transforms.py | 28 ++++++- python/mxnet/gluon/loss.py | 39 +++++++--- python/mxnet/gluon/model_zoo/vision/resnet.py | 19 +++-- python/mxnet/gluon/nn/activations.py | 8 +- python/mxnet/gluon/nn/basic_layers.py | 26 +++---- python/mxnet/gluon/nn/conv_layers.py | 47 +++++++++--- python/mxnet/gluon/rnn/rnn_layer.py | 2 +- python/mxnet/gluon/utils.py | 25 ++++--- python/mxnet/image/detection.py | 17 ++++- python/mxnet/image/image.py | 44 ++++++++--- .../mxnet/ndarray/numpy_extension/__init__.py | 1 + python/mxnet/ndarray/numpy_extension/image.py | 20 +++++ python/mxnet/numpy/__init__.py | 1 + python/mxnet/numpy/arrayprint.py | 62 ++++++++++++++++ python/mxnet/numpy/multiarray.py | 53 +++++++++++-- python/mxnet/numpy_extension/__init__.py | 2 + python/mxnet/numpy_extension/image.py | 20 +++++ python/mxnet/numpy_extension/random.py | 74 +++++++++++++++++++ .../mxnet/symbol/numpy_extension/__init__.py | 1 + python/mxnet/symbol/numpy_extension/image.py | 20 +++++ src/io/image_io.cc | 3 + src/ndarray/ndarray.cc | 2 +- src/operator/contrib/multibox_detection.cc | 4 + src/operator/contrib/multibox_prior.cc | 3 + src/operator/contrib/multibox_target.cc | 4 + src/operator/image/crop.cc | 1 + src/operator/image/image_random.cc | 13 ++++ src/operator/image/resize.cc | 1 + src/operator/leaky_relu.cc | 1 + src/operator/nn/activation.cc | 2 +- src/operator/nn/batch_norm.cc | 2 +- src/operator/nn/convolution.cc | 2 +- src/operator/nn/deconvolution.cc | 1 + src/operator/nn/dropout.cc | 2 +- src/operator/nn/fully_connected.cc | 2 +- src/operator/nn/layer_norm.cc | 2 +- src/operator/nn/pooling.cc | 2 +- .../numpy/np_elemwise_broadcast_op.cc | 11 +-- src/operator/rnn.cc | 2 +- src/operator/roi_pooling.cc | 4 + src/operator/sequence_mask.cc | 2 +- .../elemwise_binary_scalar_op_extended.cc | 3 +- .../tensor/elemwise_unary_op_basic.cc | 1 + src/operator/tensor/indexing_op.cc | 2 +- 48 files changed, 505 insertions(+), 112 deletions(-) create mode 100644 python/mxnet/ndarray/numpy_extension/image.py create mode 100644 python/mxnet/numpy/arrayprint.py create mode 100644 python/mxnet/numpy_extension/image.py create mode 100644 python/mxnet/numpy_extension/random.py create mode 100644 python/mxnet/symbol/numpy_extension/image.py diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 995a65c9ca65..ca8636cf5029 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -21,7 +21,10 @@ def _np_reshape(a, newshape, order='C'): - """Gives a new shape to an array without changing its data. + """ + reshape(a, newshape, order='C') + + Gives a new shape to an array without changing its data. 
Parameters ---------- diff --git a/python/mxnet/base.py b/python/mxnet/base.py index a4f75c6d4713..545c2ea4eb19 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -744,6 +744,7 @@ def write_all_str(module_file, module_all_list): _NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_'] _NP_EXT_OP_PREFIX = '_npx_' +_NP_EXT_OP_SUBMODULE_LIST = ['_image_'] _NP_INTERNAL_OP_PREFIX = '_npi_' @@ -784,7 +785,7 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op submodule_name_list = _NP_OP_SUBMODULE_LIST elif np_module_name == 'numpy_extension': op_name_prefix = _NP_EXT_OP_PREFIX - submodule_name_list = [] + submodule_name_list = _NP_EXT_OP_SUBMODULE_LIST elif np_module_name == 'numpy._internal': op_name_prefix = _NP_INTERNAL_OP_PREFIX submodule_name_list = []
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 09a2e2a492db..4516952f2fdc 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -26,7 +26,6 @@ import re from collections import OrderedDict - from ..base import mx_real_t, MXNetError from .. import symbol, ndarray, initializer from ..symbol import Symbol @@ -37,7 +36,7 @@ from .utils import _check_same_symbol_type, _check_all_np_ndarrays from .. import numpy_extension as _mx_npx from .. import numpy as _mx_np, numpy_extension as _mx_npx -from .. util import is_np_array +from .. util import is_np_array, np_shape, np_array class _BlockScope(object): @@ -387,7 +386,25 @@ def load_parameters(self, filename, ctx=None, allow_missing=False, `_ """ if is_np_array(): - loaded = _mx_npx.load(filename) + # Loading parameters saved as legacy NDArrays may fail within NumPy + # semantics. Check the failure type and recover from it if it happens. + try: + loaded = _mx_npx.load(filename) + except MXNetError as e: + err_msg = str(e) + if 'is_np_shape' in err_msg: + # Loading failure due to parameters saved without numpy semantics. + # Temporarily disable numpy semantics and load parameters. After it's + # done, resume the numpy semantics. This is fine because the set of + # cases numpy ndarray covers is a superset of the legacy ndarray's. + with np_array(False): + with np_shape(False): + loaded_nds = ndarray.load(filename) + assert isinstance(loaded_nds, dict),\ + 'expecting a dict type, got {}'.format(str(type(loaded_nds))) + loaded = {k: loaded_nds[k].as_np_ndarray() for k in loaded_nds} + else: + raise ValueError(err_msg) else: loaded = ndarray.load(filename) params = self._collect_params_with_prefix()
diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py index 362cc9ee6515..bdcaff52a042 100644 --- a/python/mxnet/gluon/data/vision/datasets.py +++ b/python/mxnet/gluon/data/vision/datasets.py @@ -188,8 +188,9 @@ def _get_data(self): data = np.concatenate(data) label = np.concatenate(label) - self._data = nd.array(data, dtype=data.dtype) - self._label = label + array_fn = _mx_np.array if is_np_array() else nd.array + self._data = array_fn(data, dtype=data.dtype) + self._label = array_fn(label, dtype=label.dtype) if is_np_array() else label class CIFAR100(CIFAR10): diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 54af87e9de43..ab8f8ab482df 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -23,7 +23,7 @@ from ...nn import Sequential, HybridSequential from ....
import image from ....base import numeric_types -from ...utils import _adapt_np_array +from ....util import is_np_array class Compose(Sequential): @@ -93,6 +93,8 @@ def __init__(self, dtype='float32'): self._dtype = dtype def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.cast(x, self._dtype) @@ -134,8 +136,9 @@ class ToTensor(HybridBlock): def __init__(self): super(ToTensor, self).__init__() - @_adapt_np_array def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.to_tensor(x) @@ -189,6 +192,8 @@ def __init__(self, mean=0.0, std=1.0): self._std = std def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.normalize(x, self._mean, self._std) @@ -370,8 +375,9 @@ def __init__(self, size, keep_ratio=False, interpolation=1): self._size = size self._interpolation = interpolation - @_adapt_np_array def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.resize(x, self._size, self._keep, self._interpolation) class RandomFlipLeftRight(HybridBlock): @@ -388,6 +394,8 @@ def __init__(self): super(RandomFlipLeftRight, self).__init__() def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.random_flip_left_right(x) @@ -405,6 +413,8 @@ def __init__(self): super(RandomFlipTopBottom, self).__init__() def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.random_flip_top_bottom(x) @@ -430,6 +440,8 @@ def __init__(self, brightness): self._args = (max(0, 1-brightness), 1+brightness) def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.random_brightness(x, *self._args) @@ -455,6 +467,8 @@ def __init__(self, contrast): self._args = (max(0, 1-contrast), 1+contrast) def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.random_contrast(x, *self._args) @@ -480,6 +494,8 @@ def __init__(self, saturation): self._args = (max(0, 1-saturation), 1+saturation) def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.random_saturation(x, *self._args) @@ -505,6 +521,8 @@ def __init__(self, hue): self._args = (max(0, 1-hue), 1+hue) def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.random_hue(x, *self._args) @@ -539,6 +557,8 @@ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): self._args = (brightness, contrast, saturation, hue) def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.random_color_jitter(x, *self._args) @@ -562,4 +582,6 @@ def __init__(self, alpha): self._alpha = alpha def hybrid_forward(self, F, x): + if is_np_array(): + F = F.npx return F.image.random_lighting(x, self._alpha) diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py index 6c66d4c7468d..d634e7922fae 100644 --- a/python/mxnet/gluon/loss.py +++ b/python/mxnet/gluon/loss.py @@ -258,30 +258,47 @@ def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs): weight, batch_axis, **kwargs) self._from_sigmoid = from_sigmoid - @_adapt_np_array def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None): label = _reshape_like(F, label, pred) + if is_np_array(): + relu_fn = F.npx.relu + act_fn = F.npx.activation + abs_fn = F.np.abs + mul_fn = F.np.multiply + log_fn = F.np.log + else: + relu_fn = F.relu + act_fn = F.Activation + abs_fn = F.abs + mul_fn = F.broadcast_mul + log_fn = F.log if not self._from_sigmoid: if pos_weight is None: # We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x))) - loss = F.relu(pred) - pred 
* label + \ - F.Activation(-F.abs(pred), act_type='softrelu') + loss = relu_fn(pred) - pred * label + \ + act_fn(-abs_fn(pred), act_type='softrelu') else: # We use the stable formula: x - x * z + (1 + z * pos_weight - z) * \ # (log(1 + exp(-abs(x))) + max(-x, 0)) - log_weight = 1 + F.broadcast_mul(pos_weight - 1, label) - loss = pred - pred * label + log_weight * \ - (F.Activation(-F.abs(pred), act_type='softrelu') + F.relu(-pred)) + log_weight = 1 + mul_fn(pos_weight - 1, label) + loss = pred - pred * label + log_weight *\ + (act_fn(-abs_fn(pred), act_type='softrelu') + relu_fn(-pred)) else: eps = 1e-12 if pos_weight is None: - loss = -(F.log(pred + eps) * label - + F.log(1. - pred + eps) * (1. - label)) + loss = -(log_fn(pred + eps) * label + + log_fn(1. - pred + eps) * (1. - label)) else: - loss = -(F.broadcast_mul(F.log(pred + eps) * label, pos_weight) - + F.log(1. - pred + eps) * (1. - label)) + loss = -(mul_fn(log_fn(pred + eps) * label, pos_weight) + + log_fn(1. - pred + eps) * (1. - label)) loss = _apply_weighting(F, loss, self._weight, sample_weight) - return F.mean(loss, axis=self._batch_axis, exclude=True) + if is_np_array(): + if F is ndarray: + return F.np.mean(loss, axis=tuple(range(1, loss.ndim))) + else: + return F.npx.batch_flatten(loss).mean(axis=1) + else: + return F.mean(loss, axis=self._batch_axis, exclude=True) SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py b/python/mxnet/gluon/model_zoo/vision/resnet.py index 48390decb11b..50a65ec8d2da 100644 --- a/python/mxnet/gluon/model_zoo/vision/resnet.py +++ b/python/mxnet/gluon/model_zoo/vision/resnet.py @@ -33,6 +33,7 @@ from ...block import HybridBlock from ... import nn from .... import base +from .... util import is_np_array # Helpers def _conv3x3(channels, stride, in_channels): @@ -81,7 +82,8 @@ def hybrid_forward(self, F, x): if self.downsample: residual = self.downsample(residual) - x = F.Activation(residual+x, act_type='relu') + act = F.npx.activation if is_np_array() else F.Activation + x = act(residual+x, act_type='relu') return x @@ -129,7 +131,8 @@ def hybrid_forward(self, F, x): if self.downsample: residual = self.downsample(residual) - x = F.Activation(x + residual, act_type='relu') + act = F.npx.activation if is_np_array() else F.Activation + x = act(x + residual, act_type='relu') return x @@ -165,13 +168,14 @@ def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs): def hybrid_forward(self, F, x): residual = x x = self.bn1(x) - x = F.Activation(x, act_type='relu') + act = F.npx.activation if is_np_array() else F.Activation + x = act(x, act_type='relu') if self.downsample: residual = self.downsample(x) x = self.conv1(x) x = self.bn2(x) - x = F.Activation(x, act_type='relu') + x = act(x, act_type='relu') x = self.conv2(x) return x + residual @@ -211,17 +215,18 @@ def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs): def hybrid_forward(self, F, x): residual = x x = self.bn1(x) - x = F.Activation(x, act_type='relu') + act = F.npx.activation if is_np_array() else F.Activation + x = act(x, act_type='relu') if self.downsample: residual = self.downsample(x) x = self.conv1(x) x = self.bn2(x) - x = F.Activation(x, act_type='relu') + x = act(x, act_type='relu') x = self.conv2(x) x = self.bn3(x) - x = F.Activation(x, act_type='relu') + x = act(x, act_type='relu') x = self.conv3(x) return x + residual diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py index 
6e0e7ca59d97..a3baae004311 100644 --- a/python/mxnet/gluon/nn/activations.py +++ b/python/mxnet/gluon/nn/activations.py @@ -49,9 +49,8 @@ def _alias(self): return self._act_type def hybrid_forward(self, F, x): - if is_np_array(): - F = F.npx - return F.Activation(x, act_type=self._act_type, name='fwd') + act = F.npx.activation if is_np_array() else F.Activation + return act(x, act_type=self._act_type, name='fwd') def __repr__(self): s = '{name}({_act_type})' @@ -91,7 +90,8 @@ def __init__(self, alpha, **kwargs): self._alpha = alpha def hybrid_forward(self, F, x): - return F.LeakyReLU(x, act_type='leaky', slope=self._alpha, name='fwd') + leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU + return leaky_relu(x, act_type='leaky', slope=self._alpha, name='fwd') def __repr__(self): s = '{name}({alpha})' diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index 3c43ac31b06c..a726727e55b1 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -218,10 +218,9 @@ def __init__(self, units, activation=None, use_bias=True, flatten=True, self.act = None def hybrid_forward(self, F, x, weight, bias=None): - if is_np_array(): - F = F.npx - act = F.FullyConnected(x, weight, bias, no_bias=bias is None, num_hidden=self._units, - flatten=self._flatten, name='fwd') + fc = F.npx.fully_connected if is_np_array() else F.FullyConnected + act = fc(x, weight, bias, no_bias=bias is None, num_hidden=self._units, + flatten=self._flatten, name='fwd') if self.act is not None: act = self.act(act) return act @@ -266,7 +265,7 @@ def __init__(self, rate, axes=(), **kwargs): def hybrid_forward(self, F, x): if self._rate > 0: - dropout = F.npx.Dropout if is_np_array() else F.Dropout + dropout = F.npx.dropout if is_np_array() else F.Dropout return dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False) else: copy = F.np.copy if is_np_array() else F.identity @@ -361,10 +360,9 @@ def cast(self, dtype): super(BatchNorm, self).cast(dtype) def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var): - if is_np_array(): - F = F.npx - return F.BatchNorm(x, gamma, beta, running_mean, running_var, - name='fwd', **self._kwargs) + batch_norm = F.npx.batch_norm if is_np_array() else F.BatchNorm + return batch_norm(x, gamma, beta, running_mean, running_var, + name='fwd', **self._kwargs) def __repr__(self): s = '{name}({content}' @@ -416,9 +414,8 @@ def __init__(self, input_dim, output_dim, dtype='float32', allow_deferred_init=True, grad_stype=grad_stype) def hybrid_forward(self, F, x, weight): - if is_np_array(): - F = F.npx - return F.Embedding(x, weight, name='fwd', **self._kwargs) + embedding = F.npx.embedding if is_np_array() else F.Embedding + return embedding(x, weight, name='fwd', **self._kwargs) def __repr__(self): s = '{block_name}({input_dim} -> {output_dim}, {dtype})' @@ -614,9 +611,8 @@ def __init__(self, axis=-1, epsilon=1e-5, center=True, scale=True, allow_deferred_init=True) def hybrid_forward(self, F, data, gamma, beta): - if is_np_array(): - F = F.npx - return F.LayerNorm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon) + layer_norm = F.npx.layer_norm if is_np_array() else F.LayerNorm + return layer_norm(data, gamma=gamma, beta=beta, axis=self._axis, eps=self._epsilon) def __repr__(self): s = '{name}({content}' diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 3e8516b02180..4682684662cd 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ 
b/python/mxnet/gluon/nn/conv_layers.py @@ -34,8 +34,13 @@ def _infer_weight_shape(op_name, data_shape, kwargs): - op = getattr(symbol, op_name) - sym = op(symbol.var('data', shape=data_shape), **kwargs) + data = symbol.var('data', shape=data_shape) + if is_np_array(): + op = getattr(symbol.npx, op_name) + data = data.as_np_ndarray() + else: + op = getattr(symbol, op_name) + sym = op(data, **kwargs) return sym.infer_shape_partial()[0] @@ -242,9 +247,13 @@ def __init__(self, channels, kernel_size, strides=1, padding=0, dilation=1, if isinstance(kernel_size, numeric_types): kernel_size = (kernel_size,) assert len(kernel_size) == 1, "kernel_size must be a number or a list of 1 ints" + op_name = kwargs.pop('op_name', 'Convolution') + if is_np_array(): + op_name = 'convolution' super(Conv1D, self).__init__( channels, kernel_size, strides, padding, dilation, groups, layout, - in_channels, activation, use_bias, weight_initializer, bias_initializer, **kwargs) + in_channels, activation, use_bias, weight_initializer, bias_initializer, + op_name, **kwargs) class Conv2D(_Conv): @@ -322,9 +331,13 @@ def __init__(self, channels, kernel_size, strides=(1, 1), padding=(0, 0), if isinstance(kernel_size, numeric_types): kernel_size = (kernel_size,)*2 assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints" + op_name = kwargs.pop('op_name', 'Convolution') + if is_np_array(): + op_name = 'convolution' super(Conv2D, self).__init__( channels, kernel_size, strides, padding, dilation, groups, layout, - in_channels, activation, use_bias, weight_initializer, bias_initializer, **kwargs) + in_channels, activation, use_bias, weight_initializer, bias_initializer, + op_name, **kwargs) class Conv3D(_Conv): @@ -403,9 +416,13 @@ def __init__(self, channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0), if isinstance(kernel_size, numeric_types): kernel_size = (kernel_size,)*3 assert len(kernel_size) == 3, "kernel_size must be a number or a list of 3 ints" + op_name = kwargs.pop('op_name', 'Convolution') + if is_np_array(): + op_name = 'convolution' super(Conv3D, self).__init__( channels, kernel_size, strides, padding, dilation, groups, layout, - in_channels, activation, use_bias, weight_initializer, bias_initializer, **kwargs) + in_channels, activation, use_bias, weight_initializer, bias_initializer, + op_name, **kwargs) class Conv1DTranspose(_Conv): @@ -487,10 +504,13 @@ def __init__(self, channels, kernel_size, strides=1, padding=0, output_padding=0 output_padding = (output_padding,) assert len(kernel_size) == 1, "kernel_size must be a number or a list of 1 ints" assert len(output_padding) == 1, "output_padding must be a number or a list of 1 ints" + op_name = kwargs.pop('op_name', 'Deconvolution') + if is_np_array(): + op_name = 'deconvolution' super(Conv1DTranspose, self).__init__( channels, kernel_size, strides, padding, dilation, groups, layout, in_channels, activation, use_bias, weight_initializer, - bias_initializer, op_name='Deconvolution', adj=output_padding, **kwargs) + bias_initializer, op_name=op_name, adj=output_padding, **kwargs) self.outpad = output_padding @@ -578,10 +598,13 @@ def __init__(self, channels, kernel_size, strides=(1, 1), padding=(0, 0), output_padding = (output_padding,)*2 assert len(kernel_size) == 2, "kernel_size must be a number or a list of 2 ints" assert len(output_padding) == 2, "output_padding must be a number or a list of 2 ints" + op_name = kwargs.pop('op_name', 'Deconvolution') + if is_np_array(): + op_name = 'deconvolution' super(Conv2DTranspose, 
self).__init__( channels, kernel_size, strides, padding, dilation, groups, layout, in_channels, activation, use_bias, weight_initializer, - bias_initializer, op_name='Deconvolution', adj=output_padding, **kwargs) + bias_initializer, op_name=op_name, adj=output_padding, **kwargs) self.outpad = output_padding @@ -670,10 +693,13 @@ def __init__(self, channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0), output_padding = (output_padding,)*3 assert len(kernel_size) == 3, "kernel_size must be a number or a list of 3 ints" assert len(output_padding) == 3, "output_padding must be a number or a list of 3 ints" + op_name = kwargs.pop('op_name', 'Deconvolution') + if is_np_array(): + op_name = 'deconvolution' super(Conv3DTranspose, self).__init__( channels, kernel_size, strides, padding, dilation, groups, layout, in_channels, activation, use_bias, weight_initializer, bias_initializer, - op_name='Deconvolution', adj=output_padding, **kwargs) + op_name=op_name, adj=output_padding, **kwargs) self.outpad = output_padding @@ -700,9 +726,8 @@ def _alias(self): return 'pool' def hybrid_forward(self, F, x): - if is_np_array(): - F = F.npx - return F.Pooling(x, name='fwd', **self._kwargs) + pooling = F.npx.pooling if is_np_array() else F.Pooling + return pooling(x, name='fwd', **self._kwargs) def __repr__(self): s = '{name}(size={kernel}, stride={stride}, padding={pad}, ceil_mode={ceil_mode}' diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 1104b1e2df45..9807c5e33108 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -284,7 +284,7 @@ def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs): else: rnn_args = states - rnn_fn = F.npx.RNN if is_np_array() else F.RNN + rnn_fn = F.npx.rnn if is_np_array() else F.RNN rnn = rnn_fn(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length, state_size=self._hidden_size, projection_size=self._projection_size, num_layers=self._num_layers, bidirectional=self._dir == 2, diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 542a3c6fdbb8..2822c7019a28 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -86,12 +86,19 @@ def split_data(data, num_slice, batch_axis=0, even_split=True): slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size] for i in range(num_slice)] elif even_split: - slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis) + if is_np_array(): + slices = _mx_np.split(data, indices_or_sections=num_slice, axis=batch_axis) + else: + slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis) else: - slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step) - if i < num_slice - 1 else - ndarray.slice_axis(data, batch_axis, i*step, size) - for i in range(num_slice)] + if is_np_array(): + indices = [step * i for i in range(1, num_slice)] + slices = _mx_np.split(data, indices_or_sections=indices, axis=batch_axis) + else: + slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step) + if i < num_slice - 1 else + ndarray.slice_axis(data, batch_axis, i*step, size) + for i in range(num_slice)] return slices @@ -101,7 +108,7 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True): Parameters ---------- - data : NDArray + data : NDArray or ndarray A batch of data. ctx_list : list of Context A list of Contexts. 
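A quick sketch of what the split_data/split_and_load changes above enable once numpy semantics are on; the toy batch and the duplicated CPU contexts are illustrative stand-ins, not part of this patch:

    from mxnet import np, npx
    from mxnet.gluon.utils import split_and_load

    npx.set_np()                            # enable numpy (np shape + array) semantics
    data = np.random.uniform(size=(6, 2))   # a toy batch of 6 samples
    ctxs = [npx.cpu(0), npx.cpu(0)]         # stand-ins for multiple devices
    slices = split_and_load(data, ctxs)     # two np.ndarray slices of shape (3, 2)

With `is_np_array()` true, slicing goes through `mxnet.np.split`, so the slices are already np.ndarrays and the `as_np_ndarray()` conversion pass removed in the next hunk is no longer needed.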
@@ -112,7 +119,7 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True): Returns ------- - list of NDArray + list of NDArrays or ndarrays Each corresponds to a context in `ctx_list`. """ array_fn = _mx_np.array if is_np_array() else ndarray.array @@ -121,11 +128,7 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True): if len(ctx_list) == 1: return [data.as_in_context(ctx_list[0])] - # TODO(junwu): temp solution for supporting np.ndarray - # rewrite this using np ops slices = split_data(data, len(ctx_list), batch_axis, even_split) - if is_np_array(): - slices = [i.as_np_ndarray() for i in slices] return [i.as_in_context(ctx) for i, ctx in zip(slices, ctx_list)] diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py index a70e5723072f..f3b551b53893 100644 --- a/python/mxnet/image/detection.py +++ b/python/mxnet/image/detection.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-import +# pylint: disable=unused-import, too-many-lines """Read images and perform augmentations for object detection.""" from __future__ import absolute_import, print_function @@ -34,6 +34,8 @@ from .image import RandomOrderAug, ColorJitterAug, LightingAug, ColorNormalizeAug from .image import ResizeAug, ForceResizeAug, CastAug, HueJitterAug, RandomGrayAug from .image import fixed_crop, ImageIter, Augmenter +from ..util import is_np_array +from .. import numpy as _mx_np # pylint: disable=reimported class DetAugmenter(object): @@ -762,6 +764,7 @@ def _batchify(self, batch_data, batch_label, start=0): """Override the helper function for batchifying data""" i = start batch_size = self.batch_size + array_fn = _mx_np.array if is_np_array() else nd.array try: while i < batch_size: label, s = self.next_sample() @@ -778,7 +781,7 @@ def _batchify(self, batch_data, batch_label, start=0): assert i < batch_size, 'Batch size must be multiples of augmenter output length' batch_data[i] = self.postprocess_data(datum) num_object = label.shape[0] - batch_label[i][0:num_object] = nd.array(label) + batch_label[i][0:num_object] = array_fn(label) if num_object < batch_label[i].shape[0]: batch_label[i][num_object:] = -1 i += 1 @@ -801,8 +804,14 @@ def next(self): batch_label = self._cache_label i = self._cache_idx else: - batch_data = nd.zeros((batch_size, c, h, w)) - batch_label = nd.empty(self.provide_label[0][1]) + if is_np_array(): + zeros_fn = _mx_np.zeros + empty_fn = _mx_np.empty + else: + zeros_fn = nd.zeros + empty_fn = nd.empty + batch_data = zeros_fn((batch_size, c, h, w)) + batch_label = empty_fn(self.provide_label[0][1]) batch_label[:] = -1 i = self._batchify(batch_data, batch_label) # calculate the padding diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index a142282c83a6..c48e2df91739 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -28,6 +28,7 @@ import json import warnings import numpy as np +from .. import numpy as _mx_np # pylint: disable=reimported try: @@ -40,6 +41,8 @@ from ..ndarray import _internal from .. import io from .. import recordio +from .. 
util import is_np_array +from ..ndarray.numpy import _internal as _npi def imread(filename, *args, **kwargs): @@ -80,7 +83,11 @@ def imread(filename, *args, **kwargs): >>> mx.img.imread("flower.jpg", to_rgb=0) """ - return _internal._cvimread(filename, *args, **kwargs) + if is_np_array(): + read_fn = _npi.cvimread + else: + read_fn = _internal._cvimread + return read_fn(filename, *args, **kwargs) def imresize(src, w, h, *args, **kwargs): @@ -137,7 +144,8 @@ def imresize(src, w, h, *args, **kwargs): >>> new_image """ - return _internal._cvimresize(src, w, h, *args, **kwargs) + resize_fn = _npi.cvimresize if is_np_array() else _internal._cvimresize + return resize_fn(src, w, h, *args, **kwargs) def imdecode(buf, *args, **kwargs): @@ -193,9 +201,11 @@ def imdecode(buf, *args, **kwargs): if sys.version_info[0] == 3 and not isinstance(buf, (bytes, bytearray, np.ndarray)): raise ValueError('buf must be of type bytes, bytearray or numpy.ndarray,' 'if you would like to input type str, please convert to bytes') - buf = nd.array(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8) + array_fn = _mx_np.array if is_np_array() else nd.array + buf = array_fn(np.frombuffer(buf, dtype=np.uint8), dtype=np.uint8) - return _internal._cvimdecode(buf, *args, **kwargs) + cvimdecode = _npi.cvimdecode if is_np_array() else _internal._cvimdecode + return cvimdecode(buf, *args, **kwargs) def scale_down(src_size, size): @@ -428,7 +438,7 @@ def fixed_crop(src, x0, y0, w, h, size=None, interp=2): NDArray An `NDArray` containing the cropped image. """ - out = nd.slice(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2]))) + out = src[y0:y0+h, x0:x0+w] if size is not None and (w, h) != size: sizes = (h, w, size[1], size[0]) out = imresize(out, *size, interp=_get_interp_method(interp, sizes)) @@ -1206,6 +1216,7 @@ def __init__(self, batch_size, data_shape, label_width=1, else: self.imgrec = None + array_fn = _mx_np.array if is_np_array() else nd.array if path_imglist: logging.info('%s: loading image list %s...', class_name, path_imglist) with open(path_imglist) as fin: @@ -1213,7 +1224,7 @@ def __init__(self, batch_size, data_shape, label_width=1, imgkeys = [] for line in iter(fin.readline, ''): line = line.strip().split('\t') - label = nd.array(line[1:-1], dtype=dtype) + label = array_fn(line[1:-1], dtype=dtype) key = int(line[0]) imglist[key] = (label, line[-1]) imgkeys.append(key) @@ -1227,11 +1238,11 @@ def __init__(self, batch_size, data_shape, label_width=1, key = str(index) # pylint: disable=redefined-variable-type index += 1 if len(img) > 2: - label = nd.array(img[:-1], dtype=dtype) + label = array_fn(img[:-1], dtype=dtype) elif isinstance(img[0], numeric_types): - label = nd.array([img[0]], dtype=dtype) + label = array_fn([img[0]], dtype=dtype) else: - label = nd.array(img[0], dtype=dtype) + label = array_fn(img[0], dtype=dtype) result[key] = (label, img[-1]) imgkeys.append(str(key)) self.imglist = result @@ -1367,8 +1378,14 @@ def next(self): i = self._cache_idx # clear the cache data else: - batch_data = nd.zeros((batch_size, c, h, w)) - batch_label = nd.empty(self.provide_label[0][1]) + if is_np_array(): + zeros_fn = _mx_np.zeros + empty_fn = _mx_np.empty + else: + zeros_fn = nd.zeros + empty_fn = nd.empty + batch_data = zeros_fn((batch_size, c, h, w)) + batch_label = empty_fn(self.provide_label[0][1]) i = self._batchify(batch_data, batch_label) # calculate the padding pad = batch_size - i @@ -1445,4 +1462,7 @@ def augmentation_transform(self, data): def postprocess_data(self, datum): """Final 
postprocessing step before image is loaded into the batch.""" - return nd.transpose(datum, axes=(2, 0, 1)) + if is_np_array(): + return datum.transpose(2, 0, 1) + else: + return nd.transpose(datum, axes=(2, 0, 1)) diff --git a/python/mxnet/ndarray/numpy_extension/__init__.py b/python/mxnet/ndarray/numpy_extension/__init__.py index a718274ae9ed..5be34ac9b3d5 100644 --- a/python/mxnet/ndarray/numpy_extension/__init__.py +++ b/python/mxnet/ndarray/numpy_extension/__init__.py @@ -18,6 +18,7 @@ """Module for the ops not belonging to the official numpy package.""" from . import _op +from . import image from . import _register from ._op import * # pylint: disable=wildcard-import diff --git a/python/mxnet/ndarray/numpy_extension/image.py b/python/mxnet/ndarray/numpy_extension/image.py new file mode 100644 index 000000000000..b3bd27fc503c --- /dev/null +++ b/python/mxnet/ndarray/numpy_extension/image.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Image pre-processing operators.""" + +__all__ = [] diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py index 7a9a2f60b53f..1994148d14d1 100644 --- a/python/mxnet/numpy/__init__.py +++ b/python/mxnet/numpy/__init__.py @@ -29,5 +29,6 @@ from .function_base import * # pylint: disable=wildcard-import from .stride_tricks import * # pylint: disable=wildcard-import from .io import * # pylint: disable=wildcard-import +from .arrayprint import * # pylint: disable=wildcard-import __all__ = [] diff --git a/python/mxnet/numpy/arrayprint.py b/python/mxnet/numpy/arrayprint.py new file mode 100644 index 000000000000..9be7faf1f602 --- /dev/null +++ b/python/mxnet/numpy/arrayprint.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""ndarray print format controller.""" + +from __future__ import absolute_import, print_function + +import numpy as onp +from ..util import set_module + +__all__ = ['set_printoptions'] + + +@set_module('mxnet.numpy') +def set_printoptions(precision=None, threshold=None, **kwarg): + """ + Set printing options. 
+ + These options determine the way floating point numbers and arrays are displayed. + + Parameters + ---------- + precision : int or None, optional + Number of digits of precision for floating point output (default 8). + May be `None` if `floatmode` is not `fixed`, to print as many digits as + necessary to uniquely specify the value. + threshold : int, optional + Total number of array elements which trigger summarization + rather than full repr (default 1000). + + Examples + -------- + Floating point precision can be set: + + >>> np.set_printoptions(precision=4) + >>> print(np.array([1.123456789])) + [ 1.1235] + + Long arrays can be summarised: + + >>> np.set_printoptions(threshold=5) + >>> print(np.arange(10)) + [0. 1. 2. ... 7. 8. 9.] + """ + if kwarg: + raise NotImplementedError('mxnet.numpy.set_printoptions only supports parameters' + ' precision and threshold for now.') + onp.set_printoptions(precision=precision, threshold=threshold, **kwarg)
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 2a37af7e17bc..9d9966b8066e 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -423,8 +423,53 @@ def as_np_ndarray(self): return self def __repr__(self): - """Returns a string representation of the array.""" + """ + Returns a string representation of the array. The dtype of the ndarray will not + be appended to the string if it is `float32`. The context of the ndarray will + be appended for devices other than CPU. + + Examples + -------- + >>> from mxnet import np, npx + >>> a = np.random.uniform(size=(2, 3)) + >>> a + array([[0.5488135 , 0.5928446 , 0.71518934], + [0.84426576, 0.60276335, 0.8579456 ]]) + >>> print(a) + [[0.5488135 0.5928446 0.71518934] + [0.84426576 0.60276335 0.8579456 ]] + >>> a.dtype + <class 'numpy.float32'> + >>> b = a.astype(np.float64) + >>> b + array([[0.54881352, 0.59284461, 0.71518934], + [0.84426576, 0.60276335, 0.85794562]], dtype=float64) + >>> print(b) + [[0.54881352 0.59284461 0.71518934] + [0.84426576 0.60276335 0.85794562]] + >>> b.dtype + <class 'numpy.float64'> + >>> c = a.copyto(npx.gpu(0)) + >>> c + array([[0.5488135 , 0.5928446 , 0.71518934], + [0.84426576, 0.60276335, 0.8579456 ]], ctx=gpu(0)) + >>> print(c) + [[0.5488135 0.5928446 0.71518934] + [0.84426576 0.60276335 0.8579456 ]] @gpu(0) + >>> d = b.copyto(npx.gpu(0)) + >>> d + array([[0.54881352, 0.59284461, 0.71518934], + [0.84426576, 0.60276335, 0.85794562]], dtype=float64, ctx=gpu(0)) + >>> print(d) + [[0.54881352 0.59284461 0.71518934] + [0.84426576 0.60276335 0.85794562]] @gpu(0) + """ array_str = self.asnumpy().__repr__() + dtype = self.dtype + if dtype == _np.float64: + array_str = array_str[:-1] + ', dtype=float64)' + elif dtype == _np.float32: + array_str = array_str[:array_str.rindex(', dtype=')] + ')' context = self.context if context.device_type == 'cpu': return array_str @@ -814,11 +859,7 @@ def tile(self, *args, **kwargs): raise AttributeError('mxnet.numpy.ndarray object has no attribute tile') def transpose(self, *axes): # pylint: disable=arguments-differ - """Convenience fluent method for :py:func:`transpose`. - - The arguments are the same as for :py:func:`transpose`, with - this array as data.
- """ + """Permute the dimensions of an array.""" return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None) def flip(self, *args, **kwargs): diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py index d80f0cc0f1f5..6e89c004f6a4 100644 --- a/python/mxnet/numpy_extension/__init__.py +++ b/python/mxnet/numpy_extension/__init__.py @@ -21,6 +21,7 @@ from __future__ import absolute_import from . import _op +from . import image from . import _register from ._op import * # pylint: disable=wildcard-import from ..context import * # pylint: disable=wildcard-import @@ -30,5 +31,6 @@ from ..util import set_np, use_np, reset_np from ..ndarray import waitall from .utils import * # pylint: disable=wildcard-import +from .random import * # pylint: disable=wildcard-import __all__ = [] diff --git a/python/mxnet/numpy_extension/image.py b/python/mxnet/numpy_extension/image.py new file mode 100644 index 000000000000..b3bd27fc503c --- /dev/null +++ b/python/mxnet/numpy_extension/image.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Image pre-processing operators.""" + +__all__ = [] diff --git a/python/mxnet/numpy_extension/random.py b/python/mxnet/numpy_extension/random.py new file mode 100644 index 000000000000..bfe2270358bb --- /dev/null +++ b/python/mxnet/numpy_extension/random.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for ops used in imperative programming.""" + +from __future__ import absolute_import +from .. import random as _mx_rand + + +__all__ = ['seed'] + + +def seed(seed, ctx='all'): # pylint: disable=redefined-outer-name + """Seeds the random number generators in MXNet. + + This affects the behavior of modules in MXNet that uses random number generators, + like the dropout operator and `ndarray`'s random sampling operators. + + Parameters + ---------- + seed : int + The random number seed. + + ctx : Context + The device context of the generator. 
The default is "all" which means seeding random + number generators of all devices. + + Notes + ----- + Random number generators in MXNet are device specific. + `mx.random.seed(seed_state)` sets the state of each generator using `seed_state` and the + device id. Therefore, random numbers generated from different devices can be different + even if they are seeded using the same seed. + + To produce identical random number sequences independent of the device id, + set optional `ctx` argument. This produces the same sequence of random numbers independent + of the device id, but the sequence can be different on different kind of devices as MXNet's + random number generators for CPU and GPU use different algorithms. + + Example + ------- + >>> from mxnet import np, npx + >>> npx.set_np() + >>> npx.random.seed(0) + >>> np.random.uniform() + array(0.5488135) + >>> npx.random.seed(128) + >>> np.random.uniform() + array(0.03812965) + >>> npx.random.seed(128) + >>> np.random.uniform() + array(0.03812965) + >>> npx.random.seed(128) + >>> np.random.uniform(ctx=npx.gpu(0)) + array(0.9894903, ctx=gpu(0)) + >>> npx.random.seed(128) + >>> np.random.uniform(ctx=npx.gpu(0)) + array(0.9894903, ctx=gpu(0)) + """ + _mx_rand.seed(seed_state=seed, ctx=ctx) diff --git a/python/mxnet/symbol/numpy_extension/__init__.py b/python/mxnet/symbol/numpy_extension/__init__.py index a718274ae9ed..5be34ac9b3d5 100644 --- a/python/mxnet/symbol/numpy_extension/__init__.py +++ b/python/mxnet/symbol/numpy_extension/__init__.py @@ -18,6 +18,7 @@ """Module for the ops not belonging to the official numpy package.""" from . import _op +from . import image from . import _register from ._op import * # pylint: disable=wildcard-import diff --git a/python/mxnet/symbol/numpy_extension/image.py b/python/mxnet/symbol/numpy_extension/image.py new file mode 100644 index 000000000000..b3bd27fc503c --- /dev/null +++ b/python/mxnet/symbol/numpy_extension/image.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Image pre-processing operators.""" + +__all__ = [] diff --git a/src/io/image_io.cc b/src/io/image_io.cc index c0357998f31c..db9ac7682287 100644 --- a/src/io/image_io.cc +++ b/src/io/image_io.cc @@ -357,6 +357,7 @@ inline void copyMakeBorder(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_cvimdecode) +.add_alias("_npi_cvimdecode") .describe("Decode image with OpenCV. \n" "Note: return image in RGB by default, " "instead of OpenCV's default BGR.") @@ -368,6 +369,7 @@ NNVM_REGISTER_OP(_cvimdecode) .add_arguments(ImdecodeParam::__FIELDS__()); NNVM_REGISTER_OP(_cvimread) +.add_alias("_npi_cvimread") .describe("Read and decode image with OpenCV. 
\n" "Note: return image in RGB by default, " "instead of OpenCV's default BGR.") @@ -378,6 +380,7 @@ NNVM_REGISTER_OP(_cvimread) .add_arguments(ImreadParam::__FIELDS__()); NNVM_REGISTER_OP(_cvimresize) +.add_alias("_npi_cvimresize") .describe("Resize image with OpenCV. \n") .set_num_inputs(1) .set_num_outputs(1) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index f10f5db1607a..d8cb9317342e 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1728,7 +1728,7 @@ bool NDArray::Load(dmlc::Stream *strm) { CHECK(!Imperative::Get()->is_np_shape()) << "ndarray was not saved in np shape semantics, but being loaded in np shape semantics." " Please turn off np shape semantics in Python using `with np_shape(False)`" - " to scope of the code of loading the ndarray."; + " to scope the code of loading the ndarray."; } if (magic != NDARRAY_V2_MAGIC && magic != NDARRAY_V3_MAGIC) { return LegacyLoad(strm, magic); diff --git a/src/operator/contrib/multibox_detection.cc b/src/operator/contrib/multibox_detection.cc index 37bb5a500d71..cb2dfe34bfc3 100644 --- a/src/operator/contrib/multibox_detection.cc +++ b/src/operator/contrib/multibox_detection.cc @@ -220,5 +220,9 @@ MXNET_REGISTER_OP_PROPERTY(_contrib_MultiBoxDetection, MultiBoxDetectionProp) .add_argument("loc_pred", "NDArray-or-Symbol", "Location regression predictions.") .add_argument("anchor", "NDArray-or-Symbol", "Multibox prior anchor boxes") .add_arguments(MultiBoxDetectionParam::__FIELDS__()); + +NNVM_REGISTER_OP(_contrib_MultiBoxDetection) +.add_alias("_npx_multibox_detection"); + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/multibox_prior.cc b/src/operator/contrib/multibox_prior.cc index 2ad173a2dd93..66fd2c11517a 100644 --- a/src/operator/contrib/multibox_prior.cc +++ b/src/operator/contrib/multibox_prior.cc @@ -100,5 +100,8 @@ MXNET_REGISTER_OP_PROPERTY(_contrib_MultiBoxPrior, MultiBoxPriorProp) .add_arguments(MultiBoxPriorParam::__FIELDS__()) .describe("Generate prior(anchor) boxes from data, sizes and ratios."); +NNVM_REGISTER_OP(_contrib_MultiBoxPrior) +.add_alias("_npx_multibox_prior"); + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/multibox_target.cc b/src/operator/contrib/multibox_target.cc index a1808c5a7c81..feab3977f82c 100644 --- a/src/operator/contrib/multibox_target.cc +++ b/src/operator/contrib/multibox_target.cc @@ -307,5 +307,9 @@ MXNET_REGISTER_OP_PROPERTY(_contrib_MultiBoxTarget, MultiBoxTargetProp) .add_argument("label", "NDArray-or-Symbol", "Object detection labels.") .add_argument("cls_pred", "NDArray-or-Symbol", "Class predictions.") .add_arguments(MultiBoxTargetParam::__FIELDS__()); + +NNVM_REGISTER_OP(_contrib_MultiBoxTarget) +.add_alias("_npx_multibox_target"); + } // namespace op } // namespace mxnet diff --git a/src/operator/image/crop.cc b/src/operator/image/crop.cc index 52d2f11a464b..6067f89d7033 100644 --- a/src/operator/image/crop.cc +++ b/src/operator/image/crop.cc @@ -35,6 +35,7 @@ namespace image { DMLC_REGISTER_PARAMETER(CropParam); NNVM_REGISTER_OP(_image_crop) +.add_alias("_npx__image_crop") .describe(R"code(Crop an image NDArray of shape (H x W x C) or (N x H x W x C) to the given size. 
Example: diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc index 34f4cb4d395c..0c4603ecc475 100644 --- a/src/operator/image/image_random.cc +++ b/src/operator/image/image_random.cc @@ -39,6 +39,7 @@ DMLC_REGISTER_PARAMETER(RandomLightingParam); DMLC_REGISTER_PARAMETER(RandomColorJitterParam); NNVM_REGISTER_OP(_image_to_tensor) +.add_alias("_npx__image_to_tensor") .describe(R"code(Converts an image NDArray of shape (H x W x C) or (N x H x W x C) with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W) with values in the range [0, 1] @@ -102,6 +103,7 @@ with values in the range [0, 1] .add_argument("data", "NDArray-or-Symbol", "Input ndarray"); NNVM_REGISTER_OP(_image_normalize) +.add_alias("_npx__image_normalize") .describe(R"code(Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and standard deviation. @@ -189,28 +191,34 @@ NNVM_REGISTER_OP(_backward_image_normalize) .set_attr("FCompute", NormalizeOpBackward); MXNET_REGISTER_IMAGE_AUG_OP(_image_flip_left_right) +.add_alias("_npx__image_flip_left_right") .describe(R"code()code" ADD_FILELINE) .set_attr("FCompute", FlipLeftRight); MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_flip_left_right) +.add_alias("_npx__image_random_flip_left_right") .describe(R"code()code" ADD_FILELINE) .set_attr("FCompute", RandomFlipLeftRight); MXNET_REGISTER_IMAGE_AUG_OP(_image_flip_top_bottom) +.add_alias("_npx__image_flip_top_bottom") .describe(R"code()code" ADD_FILELINE) .set_attr("FCompute", FlipTopBottom); MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_flip_top_bottom) +.add_alias("_npx__image_random_flip_top_bottom") .describe(R"code()code" ADD_FILELINE) .set_attr("FCompute", RandomFlipTopBottom); MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_brightness) +.add_alias("_npx__image_random_brightness") .describe(R"code()code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_attr("FCompute", RandomBrightness) .add_arguments(RandomEnhanceParam::__FIELDS__()); MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_contrast) +.add_alias("_npx__image_random_contrast") .describe(R"code()code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_attr("FCompute", RandomContrast) @@ -218,6 +226,7 @@ MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_contrast) MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_saturation) +.add_alias("_npx__image_random_saturation") .describe(R"code()code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_attr("FCompute", RandomSaturation) @@ -225,6 +234,7 @@ MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_saturation) MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_hue) +.add_alias("_npx__image_random_hue") .describe(R"code()code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_attr("FCompute", RandomHue) @@ -232,6 +242,7 @@ MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_hue) MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_color_jitter) +.add_alias("_npx__image_random_color_jitter") .describe(R"code()code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_attr("FCompute", RandomColorJitter) @@ -239,6 +250,7 @@ MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_color_jitter) MXNET_REGISTER_IMAGE_AUG_OP(_image_adjust_lighting) +.add_alias("_npx__image_adjust_lighting") .describe(R"code(Adjust the lighting level of the input. 
Follow the AlexNet style.)code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_attr("FCompute", AdjustLighting) @@ -246,6 +258,7 @@ MXNET_REGISTER_IMAGE_AUG_OP(_image_adjust_lighting) MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_lighting) +.add_alias("_npx__image_random_lighting") .describe(R"code(Randomly add PCA noise. Follow the AlexNet style.)code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_attr("FCompute", RandomLighting) diff --git a/src/operator/image/resize.cc b/src/operator/image/resize.cc index d93769faa8b3..d2397ea72685 100644 --- a/src/operator/image/resize.cc +++ b/src/operator/image/resize.cc @@ -34,6 +34,7 @@ namespace image { DMLC_REGISTER_PARAMETER(ResizeParam); NNVM_REGISTER_OP(_image_resize) +.add_alias("_npx__image_resize") .describe(R"code(Resize an image NDArray of shape (H x W x C) or (N x H x W x C) to the given size Example: diff --git a/src/operator/leaky_relu.cc b/src/operator/leaky_relu.cc index 214e41a84611..c25833b799d0 100644 --- a/src/operator/leaky_relu.cc +++ b/src/operator/leaky_relu.cc @@ -71,6 +71,7 @@ The following modified ReLU Activation functions are supported: .add_arguments(LeakyReLUParam::__FIELDS__()); NNVM_REGISTER_OP(LeakyReLU) +.add_alias("_npx_leaky_relu") .set_attr("FSetInputVarAttrOnCompose", [](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) { if (index == 1 && var->attrs.dict.find("__init__") == var->attrs.dict.end()) { diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc index 3d668c82d6aa..5abb6670c9b0 100644 --- a/src/operator/nn/activation.cc +++ b/src/operator/nn/activation.cc @@ -154,7 +154,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs, MXNET_OPERATOR_REGISTER_UNARY(Activation) -.add_alias("_npx_Activation") +.add_alias("_npx_activation") .describe(R"code(Applies an activation function element-wise to the input. The following activation functions are supported: diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 030f58940b04..6382d46d272d 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -520,7 +520,7 @@ std::vector BatchNormGrad(const nnvm::NodePtr& n, } NNVM_REGISTER_OP(BatchNorm) -.add_alias("_npx_BatchNorm") +.add_alias("_npx_batch_norm") .describe(R"code(Batch normalization. Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 6ab388a39b87..32ed93e4a463 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -397,7 +397,7 @@ struct ConvolutionGrad { }; NNVM_REGISTER_OP(Convolution) -.add_alias("_npx_Convolution") +.add_alias("_npx_convolution") .describe(R"code(Compute *N*-D convolution on *(N+2)*-D input. In the 2-D convolution, given input data with shape *(batch_size, diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index 09b255d009e0..9f461f4e9de3 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -408,6 +408,7 @@ struct DeconvolutionGrad { DMLC_REGISTER_PARAMETER(DeconvolutionParam); NNVM_REGISTER_OP(Deconvolution) +.add_alias("_npx_deconvolution") .describe("Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of the " "input tensor. This operation can be seen as the gradient of Convolution operation with " "respect to its input. Convolution usually reduces the size of the input. 
Transposed " diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index 72ba422932ef..29f13a4ffe97 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -65,7 +65,7 @@ struct DropoutGrad { DMLC_REGISTER_PARAMETER(DropoutParam); NNVM_REGISTER_OP(Dropout) -.add_alias("_npx_Dropout") +.add_alias("_npx_dropout") .describe(R"(Applies dropout operation to input array. - During training, each element of the input is set to zero with probability p. diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 9f30ed2353a0..06ad6d034398 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -244,7 +244,7 @@ DMLC_REGISTER_PARAMETER(FullyConnectedParam); NNVM_REGISTER_OP(FullyConnected) MXNET_ADD_SPARSE_OP_ALIAS(FullyConnected) -.add_alias("_npx_FullyConnected") +.add_alias("_npx_fully_connected") .describe(R"code(Applies a linear transformation: :math:`Y = XW^T + b`. If ``flatten`` is set to be true, then the shapes are: diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index 7c6ddcb9fce1..0b53d5091194 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -127,7 +127,7 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(LayerNorm) -.add_alias("_npx_LayerNorm") +.add_alias("_npx_layer_norm") .describe(R"code(Layer normalization. Normalizes the channels of the input tensor by mean and variance, and applies a scale ``gamma`` as diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 0df5827b76eb..485fc1345dfd 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -364,7 +364,7 @@ inline static bool BackwardPoolingStorageType(const nnvm::NodeAttrs &attrs, DMLC_REGISTER_PARAMETER(PoolingParam); NNVM_REGISTER_OP(Pooling) -.add_alias("_npx_Pooling") +.add_alias("_npx_pooling") .describe(R"code(Performs pooling on the input. The shapes for 1-D pooling are diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index 2ffa3b8f85aa..fe5aeb0457aa 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -34,14 +34,9 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); - const int itype = in_attrs->at(0); - if (itype == -1) return false; - auto is_float = [](const int dtype) { - return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16; - }; - CHECK(is_float(itype)) << "numpy binary scalar op currently only supports float dtype"; - TYPE_ASSIGN_CHECK(*out_attrs, 0, itype); - return true; + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + return in_attrs->at(0) != -1; } #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 58f190ad2d4f..244e39335a91 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -634,7 +634,7 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, #endif NNVM_REGISTER_OP(RNN) -.add_alias("_npx_RNN") +.add_alias("_npx_rnn") .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support. 
diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index bba3bea5ce6a..56c872500822 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -230,5 +230,9 @@ Example:: "corners of designated region of interest. `batch_index` indicates the index of corresponding " "image in the input array") .add_arguments(ROIPoolingParam::__FIELDS__()); + +NNVM_REGISTER_OP(ROIPooling) +.add_alias("_npx_roi_pooling"); + } // namespace op } // namespace mxnet diff --git a/src/operator/sequence_mask.cc b/src/operator/sequence_mask.cc index ca58be19d730..d7731026ce21 100644 --- a/src/operator/sequence_mask.cc +++ b/src/operator/sequence_mask.cc @@ -192,7 +192,7 @@ Example:: .add_arguments(SequenceMaskParam::__FIELDS__()); NNVM_REGISTER_OP(SequenceMask) -.add_alias("_npx_SequenceMask"); +.add_alias("_npx_sequence_mask"); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc index f027665a549b..3a687c2aa062 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc @@ -84,7 +84,8 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_hypot_scalar) cpu, mshadow_op::hypot_grad_left>); NNVM_REGISTER_OP(smooth_l1) - .describe(R"code(Calculate Smooth L1 Loss(lhs, scalar) by summing +.add_alias("_npx_smooth_l1") +.describe(R"code(Calculate Smooth L1 Loss(lhs, scalar) by summing .. math:: diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index a955508a9089..3dffc73884a9 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -650,6 +650,7 @@ Example:: DMLC_REGISTER_PARAMETER(CastParam); NNVM_REGISTER_OP(Cast) .add_alias("cast") +.add_alias("_npx_cast") .describe(R"code(Casts all elements of the input to a new type. .. note:: ``Cast`` is deprecated. Use ``cast`` instead. diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index f229fefa731a..ad4e54db54f1 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -466,7 +466,7 @@ DMLC_REGISTER_PARAMETER(ScatterNDParam); NNVM_REGISTER_OP(Embedding) MXNET_ADD_SPARSE_OP_ALIAS(Embedding) -.add_alias("_npx_Embedding") +.add_alias("_npx_embedding") .describe(R"code(Maps integer indices to vector representations (embeddings). This operator maps words to real-valued vectors in a high-dimensional space,
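The `_npx_*` aliases registered throughout this patch expose the legacy CamelCase operators under lower_snake_case names in the `npx` namespace. A hedged sketch of the `_npx_embedding` alias added above, assuming the usual `Embedding` parameters (`input_dim`, `output_dim`) carry over unchanged; the vocabulary and indices are toy values:

    from mxnet import np, npx

    npx.set_np()
    weight = np.random.uniform(size=(3, 4))   # 3-token vocabulary, 4-dim embeddings
    idx = np.array([0, 2])                    # indices to look up
    vecs = npx.embedding(idx, weight, input_dim=3, output_dim=4)   # shape (2, 4)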