Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Port several np ops to master #15867

Merged
merged 2 commits into from
Aug 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 171 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from . import _internal as _npi
from ..ndarray import NDArray

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', 'linspace']
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -97,6 +98,58 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None):
return _npi.ones(shape=shape, ctx=ctx, dtype=dtype)


@set_module('mxnet.ndarray.numpy')
def arange(start, stop=None, step=1, dtype=None, ctx=None):
"""Return evenly spaced values within a given interval.

Values are generated within the half-open interval ``[start, stop)``
(in other words, the interval including `start` but excluding `stop`).
For integer arguments the function is equivalent to the Python built-in
`range` function, but returns an ndarray rather than a list.

Parameters
----------
start : number, optional
Start of interval. The interval includes this value. The default
start value is 0.
stop : number
End of interval. The interval does not include this value, except
in some cases where `step` is not an integer and floating point
round-off affects the length of `out`.
step : number, optional
Spacing between values. For any output `out`, this is the distance
between two adjacent values, ``out[i+1] - out[i]``. The default
step size is 1. If `step` is specified as a position argument,
`start` must also be given.
dtype : dtype
The type of the output array. The default is `float32`.

Returns
-------
arange : ndarray
Array of evenly spaced values.

For floating point arguments, the length of the result is
``ceil((stop - start)/step)``. Because of floating point overflow,
this rule may result in the last element of `out` being greater
than `stop`.
"""
if dtype is None:
dtype = 'float32'
if ctx is None:
ctx = current_context()
if stop is None:
stop = start
start = 0
if step is None:
step = 1
if start is None and stop is None:
raise ValueError('start and stop cannot be both None')
if step == 0:
raise ZeroDivisionError('step cannot be 0')
return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx)


#pylint: disable= too-many-arguments, no-member, protected-access
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
""" Helper function for element-wise operation.
Expand Down Expand Up @@ -462,3 +515,120 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step
else:
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)


@set_module('mxnet.ndarray.numpy')
def expand_dims(a, axis):
"""Expand the shape of an array.

Insert a new axis that will appear at the `axis` position in the expanded

Parameters
----------
a : ndarray
Input array.
axis : int
Position in the expanded axes where the new axis is placed.

Returns
-------
res : ndarray
Output array. The number of dimensions is one greater than that of
the input array.
"""
return _npi.expand_dims(a, axis)


def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs):
"""Helper function for unary operators.

Parameters
----------
x : ndarray or scalar
Input of the unary operator.
fn_array : function
Function to be called if x is of ``ndarray`` type.
fn_scalar : function
Function to be called if x is a Python scalar.
out : ndarray
The buffer ndarray for storing the result of the unary function.

Returns
-------
out : mxnet.numpy.ndarray or scalar
Result array or scalar.
"""
if isinstance(x, numeric_types):
return fn_scalar(x, **kwargs)
elif isinstance(x, NDArray):
return fn_array(x, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.ndarray.numpy')
def tile(A, reps):
r"""
Construct an array by repeating A the number of times given by reps.

If `reps` has length ``d``, the result will have dimension of
``max(d, A.ndim)``.

If ``A.ndim < d``, `A` is promoted to be d-dimensional by prepending new
axes. So a shape (3,) array is promoted to (1, 3) for 2-D replication,
or shape (1, 1, 3) for 3-D replication. If this is not the desired
behavior, promote `A` to d-dimensions manually before calling this
function.

If ``A.ndim > d``, `reps` is promoted to `A`.ndim by pre-pending 1's to it.
Thus for an `A` of shape (2, 3, 4, 5), a `reps` of (2, 2) is treated as
(1, 1, 2, 2).

Parameters
----------
A : ndarray or scalar
An input array or a scalar to repeat.
reps : a single integer or tuple of integers
The number of repetitions of `A` along each axis.

Returns
-------
c : ndarray
The tiled output array.

Examples
--------
>>> a = np.array([0, 1, 2])
>>> np.tile(a, 2)
array([0., 1., 2., 0., 1., 2.])
>>> np.tile(a, (2, 2))
array([[0., 1., 2., 0., 1., 2.],
[0., 1., 2., 0., 1., 2.]])
>>> np.tile(a, (2, 1, 2))
array([[[0., 1., 2., 0., 1., 2.]],
[[0., 1., 2., 0., 1., 2.]]])

>>> b = np.array([[1, 2], [3, 4]])
>>> np.tile(b, 2)
array([[1., 2., 1., 2.],
[3., 4., 3., 4.]])
>>> np.(b, (2, 1))
array([[1., 2.],
[3., 4.],
[1., 2.],
[3., 4.]])

>>> c = np.array([1,2,3,4])
>>> np.tile(c,(4,1))
array([[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.],
[1., 2., 3., 4.]])

Scalar as input:

>>> np.tile(2, 3)
array([2, 2, 2]) # repeating integer `2`

"""
return _unary_func_helper(A, _npi.tile, _np.tile, reps=reps)
4 changes: 4 additions & 0 deletions python/mxnet/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,9 @@
from . import _register
from ._op import * # pylint: disable=wildcard-import
from .utils import * # pylint: disable=wildcard-import
from .function_base import * # pylint: disable=wildcard-import
from .stride_tricks import * # pylint: disable=wildcard-import
from .io import * # pylint: disable=wildcard-import
from .arrayprint import * # pylint: disable=wildcard-import

__all__ = []
62 changes: 62 additions & 0 deletions python/mxnet/numpy/arrayprint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""ndarray print format controller."""

from __future__ import absolute_import, print_function

import numpy as onp
from ..util import set_module

__all__ = ['set_printoptions']


@set_module('mxnet.numpy')
def set_printoptions(precision=None, threshold=None, **kwarg):
"""
Set printing options.

These options determine the way floating point numbers and arrays are displayed.

Parameters
----------
precision : int or None, optional
Number of digits of precision for floating point output (default 8).
May be `None` if `floatmode` is not `fixed`, to print as many digits as
necessary to uniquely specify the value.
threshold : int, optional
Total number of array elements which trigger summarization
rather than full repr (default 1000).

Examples
--------
Floating point precision can be set:

>>> np.set_printoptions(precision=4)
>>> print(np.array([1.123456789]))
[ 1.1235]

Long arrays can be summarised:

>>> np.set_printoptions(threshold=5)
>>> print(np.arange(10))
[0. 1. 2. ... 7. 8. 9.]
"""
if kwarg:
raise NotImplementedError('mxnet.numpy.set_printoptions only supports parameters'
' precision and threshold for now.')
onp.set_printoptions(precision=precision, threshold=threshold, **kwarg)
115 changes: 115 additions & 0 deletions python/mxnet/numpy/function_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Numpy basic functions."""
from __future__ import absolute_import

from .stride_tricks import broadcast_arrays

__all__ = ['meshgrid']


def meshgrid(*xi, **kwargs):
"""
Return coordinate matrices from coordinate vectors.

Make N-D coordinate arrays for vectorized evaluations of
N-D scalar/vector fields over N-D grids, given
one-dimensional coordinate arrays x1, x2,..., xn.

Parameters
----------
x1, x2,..., xn : ndarrays
1-D arrays representing the coordinates of a grid.
indexing : {'xy', 'ij'}, optional
Cartesian ('xy', default) or matrix ('ij') indexing of output.
See Notes for more details.

sparse : bool, optional
If True a sparse grid is returned in order to conserve memory.
Default is False. Please note that `sparse=True` is currently
not supported.

copy : bool, optional
If False, a view into the original arrays are returned in order to
conserve memory. Default is True. Please note that `copy=False`
is currently not supported.

Returns
-------
X1, X2,..., XN : ndarray
For vectors `x1`, `x2`,..., 'xn' with lengths ``Ni=len(xi)`` ,
return ``(N1, N2, N3,...Nn)`` shaped arrays if indexing='ij'
or ``(N2, N1, N3,...Nn)`` shaped arrays if indexing='xy'
with the elements of `xi` repeated to fill the matrix along
the first dimension for `x1`, the second for `x2` and so on.

Notes
-----
This function supports both indexing conventions through the indexing
keyword argument. Giving the string 'ij' returns a meshgrid with
matrix indexing, while 'xy' returns a meshgrid with Cartesian indexing.
In the 2-D case with inputs of length M and N, the outputs are of shape
(N, M) for 'xy' indexing and (M, N) for 'ij' indexing. In the 3-D case
with inputs of length M, N and P, outputs are of shape (N, M, P) for
'xy' indexing and (M, N, P) for 'ij' indexing. The difference is
illustrated by the following code snippet::

xv, yv = np.meshgrid(x, y, sparse=False, indexing='ij')
for i in range(nx):
for j in range(ny):
# treat xv[i,j], yv[i,j]

xv, yv = np.meshgrid(x, y, sparse=False, indexing='xy')
for i in range(nx):
for j in range(ny):
# treat xv[j,i], yv[j,i]

In the 1-D and 0-D case, the indexing and sparse keywords have no effect.
"""
ndim = len(xi)

copy_ = kwargs.pop('copy', True)
if not copy_:
raise NotImplementedError('copy=False is not implemented')
sparse = kwargs.pop('sparse', False)
if sparse:
raise NotImplementedError('sparse=False is not implemented')
indexing = kwargs.pop('indexing', 'xy')

if kwargs:
raise TypeError("meshgrid() got an unexpected keyword argument '%s'"
% (list(kwargs)[0],))

if indexing not in ['xy', 'ij']:
raise ValueError(
"Valid values for `indexing` are 'xy' and 'ij'.")

s0 = (1,) * ndim
output = [x.reshape(s0[:i] + (-1,) + s0[i + 1:])
for i, x in enumerate(xi)]

if indexing == 'xy' and ndim > 1:
# switch first and second axis
output[0] = output[0].reshape(1, -1, *s0[2:])
output[1] = output[1].reshape(-1, 1, *s0[2:])

if not sparse:
# Return the full N-D matrix (not only the 1-D vector)
output = broadcast_arrays(*output)

return output
Loading