[FFI] part2: npx.pick, npx.convolution, npx.deconvolution #20101

Merged · 3 commits · Mar 30, 2021
3 changes: 2 additions & 1 deletion python/mxnet/base.py
@@ -796,7 +796,8 @@ def write_all_str(module_file, module_all_list):
_NP_EXT_OP_SUBMODULE_LIST = ['_image_', '_random_']
_NP_EXT_OP_IMPLEMENTED_SET = {'_npx_softmax', '_npx_log_softmax', '_npx_masked_softmax',
'_npx_masked_log_softmax', '_npx_activation',
'_npx_batch_norm', '_npx_fully_connected'}
'_npx_batch_norm', '_npx_fully_connected', '_npx_pick',
'_npx_convolution', '_npx_deconvolution'}

_NP_INTERNAL_OP_PREFIX = '_npi_'

292 changes: 291 additions & 1 deletion python/mxnet/ndarray/numpy_extension/_op.py
@@ -25,7 +25,8 @@


__all__ = ['softmax', 'log_softmax', 'masked_softmax', 'masked_log_softmax',
'activation', 'batch_norm', 'fully_connected']
'activation', 'batch_norm', 'fully_connected', 'pick', 'convolution',
'deconvolution']


# pylint: disable=too-many-arguments
@@ -418,3 +419,292 @@ def fully_connected(x, weight, bias=None, num_hidden=None,
assert bias is not None, "Missing bias parameter"
return _api_internal.fully_connected(x, weight, bias, num_hidden,
no_bias, flatten)


# pylint: disable=too-many-arguments
@set_module('mxnet.ndarray.numpy_extension')
def pick(data, index, axis=-1, mode='clip', keepdims=False):
r"""Picks elements from an input array according to the input indices along the given axis.

Given an input array of shape ``(d0, d1)`` and indices of shape ``(i0,)``, the result will be
an output array of shape ``(i0,)`` with::

output[i] = input[i, indices[i]]

By default, if any index mentioned is too large, it is replaced by the index that addresses
the last element along an axis (the `clip` mode).

This function supports n-dimensional input and (n-1)-dimensional indices arrays.

Parameters
----------
data : NDArray
The input array
index : NDArray
The index array
axis : int or None, optional, default='-1'
The axis along which to pick the elements.
Negative values mean indexing from right to left.
If `None`, the elements in the index are picked w.r.t. the flattened input.
keepdims : boolean, optional, default=0
If true, the axis where we pick the elements is
left in the result as a dimension with size one.
mode : {'clip', 'wrap'}, optional, default='clip'
Specify how out-of-bound indices behave. Default is "clip".
"clip" means clip to the range, so an index that is too large is
replaced by the index that addresses the last element along that axis.
"wrap" means to wrap around.

out : NDArray, optional
The output NDArray to hold the result.

Returns
-------
out : NDArray or list of NDArrays
The output of this function.

Example
-------
>>> x = np.array([[1., 2.],[3., 4.],[5., 6.]])

picks elements with specified indices along axis 0

>>> npx.pick(x, np.array([0, 1]), 0)
array([1., 4.])

picks elements with specified indices along axis 1

>>> npx.pick(x, np.array([0, 1, 0]), 1)
array([1., 4., 5.])

picks elements with specified indices along axis 1 using 'wrap' mode
to wrap indices that would normally be out of bounds

>>> npx.pick(x, np.array([2, -1, -2]), 1, mode='wrap')
array([1., 4., 5.])

picks elements with specified indices along axis 1 and dims are maintained

>>> npx.pick(x, np.array([[1.], [0.], [2.]]), 1, keepdims=True)
array([[2.],
[3.],
[6.]])
"""
return _api_internal.pick(data, index, axis, mode, keepdims)
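
The picking rule above, ``output[i] = input[i, indices[i]]``, matches NumPy's ``take_along_axis``; a minimal stand-alone sketch in plain NumPy (independent of the MXNet FFI path, values taken from the docstring example) illustrates the semantics:

import numpy as np

x = np.array([[1., 2.], [3., 4.], [5., 6.]])
idx = np.array([0, 1, 0])

# Same result as npx.pick(x, idx, 1): pick x[i, idx[i]] for each row i.
picked = np.take_along_axis(x, idx[:, None], axis=1).squeeze(axis=1)
print(picked)  # [1. 4. 5.]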


# pylint: disable=too-many-arguments
@set_module('mxnet.ndarray.numpy_extension')
def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
pad=None, num_filter=1, num_group=1, workspace=1024, no_bias=False,
cudnn_tune=None, cudnn_off=False, layout=None):
r"""Compute *N*-D convolution on *(N+2)*-D input.

In the 2-D convolution, given input data with shape *(batch_size,
channel, height, width)*, the output is computed by

.. math::

out[n,i,:,:] = bias[i] + \sum_{j=0}^{channel} data[n,j,:,:] \star
weight[i,j,:,:]

where :math:`\star` is the 2-D cross-correlation operator.

For general 2-D convolution, the shapes are

- **data**: *(batch_size, channel, height, width)*
- **weight**: *(num_filter, channel, kernel[0], kernel[1])*
- **bias**: *(num_filter,)*
- **out**: *(batch_size, num_filter, out_height, out_width)*.

Define::

f(x,k,p,s,d) = floor((x+2*p-d*(k-1)-1)/s)+1

then we have::

out_height=f(height, kernel[0], pad[0], stride[0], dilate[0])
out_width=f(width, kernel[1], pad[1], stride[1], dilate[1])

If ``no_bias`` is set to be true, then the ``bias`` term is ignored.

The default data ``layout`` is *NCHW*, namely *(batch_size, channel, height,
width)*. We can choose other layouts such as *NHWC*.

If ``num_group`` is larger than 1, denoted by *g*, then split the input ``data``
evenly into *g* parts along the channel axis, and also evenly split ``weight``
along the first dimension. Next compute the convolution on the *i*-th part of
the data with the *i*-th weight part. The output is obtained by concatenating all
the *g* results.

1-D convolution does not have *height* dimension but only *width* in space.

- **data**: *(batch_size, channel, width)*
- **weight**: *(num_filter, channel, kernel[0])*
- **bias**: *(num_filter,)*
- **out**: *(batch_size, num_filter, out_width)*.

3-D convolution adds an additional *depth* dimension besides *height* and
*width*. The shapes are

- **data**: *(batch_size, channel, depth, height, width)*
- **weight**: *(num_filter, channel, kernel[0], kernel[1], kernel[2])*
- **bias**: *(num_filter,)*
- **out**: *(batch_size, num_filter, out_depth, out_height, out_width)*.

Both ``weight`` and ``bias`` are learnable parameters.

There are other options to tune the performance.

- **cudnn_tune**: enabling this option leads to higher startup time but may give
faster speed. Options are

- **off**: no tuning
- **limited_workspace**: run a test and pick the fastest algorithm that doesn't
exceed the workspace limit.
- **fastest**: pick the fastest algorithm and ignore workspace limit.
- **None** (default): the behavior is determined by environment variable
``MXNET_CUDNN_AUTOTUNE_DEFAULT``. 0 for off, 1 for limited workspace
(default), 2 for fastest.

- **workspace**: A large number leads to more (GPU) memory usage but may improve
the performance.

Parameters
----------
data : NDArray
Input data to the ConvolutionOp.
weight : NDArray
Weight matrix.
bias : NDArray
Bias parameter.
kernel : Shape(tuple), required
Convolution kernel size: (w,), (h, w) or (d, h, w)
stride : Shape(tuple), optional, default=[]
Convolution stride: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
dilate : Shape(tuple), optional, default=[]
Convolution dilate: (w,), (h, w) or (d, h, w). Defaults to 1 for each dimension.
pad : Shape(tuple), optional, default=[]
Zero pad for convolution: (w,), (h, w) or (d, h, w). Defaults to no padding.
num_filter : int (non-negative), required
Number of convolution filters (output channels).
num_group : int (non-negative), optional, default=1
Number of group partitions.
workspace : long (non-negative), optional, default=1024
Maximum temporary workspace allowed (MB) in convolution. This parameter has two usages.
When CUDNN is not used, it determines the effective batch size of the convolution kernel.
When CUDNN is used, it controls the maximum temporary storage used for tuning the best
CUDNN kernel when `limited_workspace` strategy is used.
no_bias : boolean, optional, default=0
Whether to disable bias parameter.
cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
Whether to pick the convolution algorithm by running a performance test.
cudnn_off : boolean, optional, default=0
Turn off cudnn for this layer.
layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
Set layout for input, output and weight. Empty for
default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
NHWC and NDHWC are only supported on GPU.

Returns
-------
out : NDArray or list of NDArrays
The output of this function.
"""
assert data is not None and weight is not None and kernel is not None, \
"Missing input data, weight or kernel"
assert num_filter > 0, "Number of output filters should be greater than 0"
assert workspace > 0, "Maximum temporary workspace should be greater than 0"
if no_bias:
assert bias is None, "bias must be None when no_bias is True"
return _api_internal.convolution(data, weight, kernel, stride, dilate, pad,
num_filter, num_group, workspace, no_bias,
cudnn_tune, cudnn_off, layout)
else:
assert bias is not None, "bias is required when no_bias is False"
return _api_internal.convolution(data, weight, bias, kernel, stride, dilate, pad,
num_filter, num_group, workspace, no_bias,
cudnn_tune, cudnn_off, layout)
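
As a quick sanity check of the shape formula ``f`` defined in the docstring above, here is a minimal stand-alone sketch (plain Python; the input size, kernel, padding, stride and dilation values are arbitrary examples):

def conv_out_size(x, k, p, s, d):
    # f(x, k, p, s, d) = floor((x + 2*p - d*(k-1) - 1) / s) + 1
    return (x + 2 * p - d * (k - 1) - 1) // s + 1

# 32x32 input, 3x3 kernel, padding 1, stride 2, dilation 1
out_height = conv_out_size(32, k=3, p=1, s=2, d=1)  # 16
out_width = conv_out_size(32, k=3, p=1, s=2, d=1)   # 16
print(out_height, out_width)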


# pylint: disable=too-many-arguments
@set_module('mxnet.ndarray.numpy_extension')
def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None,
pad=None, adj=None, target_shape=None, num_filter=1, num_group=1,
workspace=512, no_bias=False, cudnn_tune=None,
cudnn_off=False, layout=None):
r"""Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of
the input tensor. This operation can be seen as the gradient of the Convolution operation
with respect to its input. Convolution usually reduces the size of the input.
Transposed convolution works the other way, going from a smaller input
to a larger output while preserving the connectivity pattern.

Parameters
----------
data : NDArray
Input tensor to the deconvolution operation.
weight : NDArray
Weights representing the kernel.
bias : NDArray
Bias added to the result after the deconvolution operation.
kernel : Shape(tuple), required
Deconvolution kernel size: (w,), (h, w) or (d, h, w).
This is the same as the kernel size used for the corresponding convolution.
stride : Shape(tuple), optional, default=[]
The stride used for the corresponding convolution: (w,), (h, w) or (d, h, w).
Defaults to 1 for each dimension.
dilate : Shape(tuple), optional, default=[]
Dilation factor for each dimension of the input: (w,), (h, w) or (d, h, w).
Defaults to 1 for each dimension.
pad : Shape(tuple), optional, default=[]
The amount of implicit zero padding added during convolution for each dimension of
the input: (w,), (h, w) or (d, h, w). ``(kernel-1)/2`` is usually a good choice.
If `target_shape` is set, `pad` will be ignored and a padding that will generate
the target shape will be used. Defaults to no padding.
adj : Shape(tuple), optional, default=[]
Adjustment for output shape: (w,), (h, w) or (d, h, w).
If `target_shape` is set, `adj` will be ignored and computed accordingly.
target_shape : Shape(tuple), optional, default=[]
Shape of the output tensor: (w,), (h, w) or (d, h, w).
num_filter : int (non-negative), required
Number of output filters.
num_group : int (non-negative), optional, default=1
Number of group partitions.
workspace : long (non-negative), optional, default=512
Maximum temporary workspace allowed (MB) in deconvolution. This parameter has two usages.
When CUDNN is not used, it determines the effective batch size of the deconvolution kernel.
When CUDNN is used, it controls the maximum temporary storage used for tuning
the best CUDNN kernel when `limited_workspace` strategy is used.
no_bias : boolean, optional, default=0
Whether to disable bias parameter.
cudnn_tune : {None, 'fastest', 'limited_workspace', 'off'}, optional, default='None'
Whether to pick the deconvolution algorithm by running a performance test.
cudnn_off : boolean, optional, default=0
Turn off cudnn for this layer.
layout : {None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'}, optional, default='None'
Set layout for input, output and weight. Empty for
default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d.
NHWC and NDHWC are only supported on GPU.

out : NDArray, optional
The output NDArray to hold the result.

Returns
-------
out : NDArray or list of NDArrays
The output of this function.
"""
assert data is not None and weight is not None and kernel is not None, \
"Missing input data, weight or kernel"
assert num_filter > 0, "Number of output filters should be greater than 0"
assert workspace > 0, "Maximum temporary workspace should be greater than 0"
if no_bias:
assert bias is None, "bias must be None when no_bias is True"
return _api_internal.deconvolution(data, weight, kernel, stride, dilate, pad,
adj, target_shape, num_filter, num_group,
workspace, no_bias, cudnn_tune, cudnn_off, layout)
else:
assert bias is not None, "bias is required when no_bias is False"
return _api_internal.deconvolution(data, weight, bias, kernel, stride, dilate, pad,
adj, target_shape, num_filter, num_group,
workspace, no_bias, cudnn_tune, cudnn_off, layout)
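
The deconvolution docstring does not spell out the output-shape arithmetic. The sketch below assumes the standard transposed-convolution shape rule (mirroring the ``f`` formula in the convolution docstring above; the rule itself is an assumption, not taken from this diff) to show how deconvolution undoes the spatial reduction of a convolution:

def conv_out_size(x, k, p, s, d):
    # forward convolution: floor((x + 2*p - d*(k-1) - 1) / s) + 1
    return (x + 2 * p - d * (k - 1) - 1) // s + 1

def deconv_out_size(y, k, p, s, d, adj=0):
    # assumed transposed-convolution rule; `adj` resolves the ambiguity
    # introduced by the floor in the forward formula
    return (y - 1) * s - 2 * p + d * (k - 1) + 1 + adj

h = 32
y = conv_out_size(h, k=3, p=1, s=2, d=1)              # 16
print(deconv_out_size(y, k=3, p=1, s=2, d=1))         # 31
print(deconv_out_size(y, k=3, p=1, s=2, d=1, adj=1))  # 32, recovers the original height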