Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Numpy compatible linspace #15256

Merged
merged 11 commits into from
Jun 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
from ...util import _sanity_check_params, set_module
from ...context import current_context
from . import _internal as _npi
from ..ndarray import NDArray

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip', 'split', 'swapaxes', 'expand_dims', 'tile']
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -629,3 +630,63 @@ def tile(A, reps):
The tiled output array.
"""
return _npi.tile(A, reps)


@set_module('mxnet.ndarray.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): #pylint: disable=too-many-arguments
"""Return evenly spaced numbers over a specified interval.

Returns num evenly spaced samples, calculated over the interval [start, stop].
The endpoint of the interval can optionally be excluded.

Parameters
----------
start : array_like
The starting value of the sequence.
stop : array_like
The end value of the sequence, unless endpoint is set to False. In
that case, the sequence consists of all but the last of num + 1
evenly spaced samples, so that stop is excluded. Note that the step
size changes when endpoint is False.
num : int, optional
Number of samples to generate. Default is 50. Must be non-negative.
endpoint : bool, optional
If True, stop is the last sample. Otherwise, it is not included.
Default is True.
retstep: bool, optional
If True, return (samples, step), where step is the spacing between samples.
dtype: dtype, optional
The type of the output array. If dtype is not given, infer the data
type from the other input arguments.
axis : int, optional
The axis in the result to store the samples. Relevant only if start or
stop are array-like. By default (0), the samples will be along a new
axis inserted at the beginning. Use -1 to get an axis at the end.
Returns
-------
samples : ndarray
There are num equally spaced samples in the closed interval
`[start, stop]` or the half-open interval `[start, stop)`
(depending on whether endpoint is True or False).
step : float, optional
Only returned if retstep is True
Size of spacing between samples.

Notes
-----
This function currently does not support ``start`` and ``stop`` as ndarrays and
axis could only be 0 now.
"""
if isinstance(start, (list, _np.ndarray, NDArray)) or \
isinstance(stop, (list, _np.ndarray, NDArray)):
raise NotImplementedError('start and stop only support int')
if axis != 0:
raise NotImplementedError("the function only support axis 0")
ctx = kwargs.pop('ctx', current_context())
if ctx is None:
ctx = current_context()
if retstep:
step = (stop - start) / (num - 1)
return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step)
else:
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)
45 changes: 44 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip', 'split', 'swapaxes', 'expand_dims', 'tile']
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace']


# This function is copied from ndarray.py since pylint
Expand Down Expand Up @@ -1791,3 +1791,46 @@ def tile(A, reps):
The tiled output array.
"""
return _npi.tile(A, reps)


@set_module('mxnet.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs):
"""Return evenly spaced numbers over a specified interval.

Returns num evenly spaced samples, calculated over the interval [start, stop].
The endpoint of the interval can optionally be excluded.

Parameters
----------
start : array_like
The starting value of the sequence.
stop : array_like
The end value of the sequence, unless endpoint is set to False. In
that case, the sequence consists of all but the last of num + 1
evenly spaced samples, so that stop is excluded. Note that the step
size changes when endpoint is False.
num : int, optional
Number of samples to generate. Default is 50. Must be non-negative.
endpoint : bool, optional
If True, stop is the last sample. Otherwise, it is not included.
Default is True.
retstep: bool, optional
If True, return (samples, step), where step is the spacing between samples.
dtype: dtype, optional
The type of the output array. If dtype is not given, infer the data
type from the other input arguments.
axis : int, optional
The axis in the result to store the samples. Relevant only if start or
stop are array-like. By default (0), the samples will be along a new
axis inserted at the beginning. Use -1 to get an axis at the end.
Returns
-------
samples : ndarray
There are num equally spaced samples in the closed interval
`[start, stop]` or the half-open interval `[start, stop)`
(depending on whether endpoint is True or False).
step : float, optional
Only returned if retstep is True
Size of spacing between samples.
"""
return _mx_nd_np.linspace(start, stop, num, endpoint, retstep, dtype, axis, **kwargs)
62 changes: 61 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
'expand_dims', 'tile']
'expand_dims', 'tile', 'linspace']


def _num_outputs(sym):
Expand Down Expand Up @@ -1307,4 +1307,64 @@ def tile(A, reps):
return _npi.tile(A, reps)


@set_module('mxnet.symbol.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, **kwargs): # pylint: disable=too-many-arguments
"""Return evenly spaced numbers over a specified interval.

Returns num evenly spaced samples, calculated over the interval [start, stop].
The endpoint of the interval can optionally be excluded.

Parameters
----------
start : array_like
The starting value of the sequence.
stop : array_like
The end value of the sequence, unless endpoint is set to False. In
that case, the sequence consists of all but the last of num + 1
evenly spaced samples, so that stop is excluded. Note that the step
size changes when endpoint is False.
num : int, optional
Number of samples to generate. Default is 50. Must be non-negative.
endpoint : bool, optional
If True, stop is the last sample. Otherwise, it is not included.
Default is True.
retstep: bool, optional
If True, return (samples, step), where step is the spacing between samples.
dtype: dtype, optional
The type of the output array. If dtype is not given, infer the data
type from the other input arguments.
axis : int, optional
The axis in the result to store the samples. Relevant only if start or
stop are array-like. By default (0), the samples will be along a new
axis inserted at the beginning. Use -1 to get an axis at the end.
Returns
-------
samples : ndarray
There are num equally spaced samples in the closed interval
`[start, stop]` or the half-open interval `[start, stop)`
(depending on whether endpoint is True or False).
step : float, optional
Only returned if retstep is True
Size of spacing between samples.

Notes
-----
This function currently does not support ``start`` and ``stop`` as ndarrays and
axis could only be 0 now.
"""
if isinstance(start, (list, _np.ndarray)) or \
isinstance(stop, (list, _np.ndarray)):
raise NotImplementedError('start and stop only support int')
if axis != 0:
raise NotImplementedError("the function only support axis 0")
ctx = kwargs.pop('ctx', current_context())
if ctx is None:
ctx = current_context()
if retstep:
step = (stop - start) / (num - 1)
return (_npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype), step)
else:
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)


_set_np_symbol_class(_Symbol)
1 change: 1 addition & 0 deletions src/operator/tensor/init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ NNVM_REGISTER_OP(_arange)
.add_arguments(RangeParam::__FIELDS__());

NNVM_REGISTER_OP(_linspace)
.add_alias("_npi_linspace")
.describe("Return evenly spaced numbers over a specified interval. Similar to Numpy")
.set_num_inputs(0)
.set_num_outputs(1)
Expand Down
73 changes: 72 additions & 1 deletion tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
import numpy as _np
import mxnet as mx
from mxnet import np, npx
from mxnet.base import MXNetError
from mxnet.gluon import HybridBlock
from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray
from mxnet.test_utils import check_numeric_gradient
from common import with_seed
from common import assertRaises, with_seed
import random


Expand Down Expand Up @@ -554,6 +555,76 @@ def hybrid_forward(self, F, x):
assert same(mx_out.asnumpy(), np_out)


@with_seed()
@npx.use_np_shape
def test_np_linspace():
configs = [
(0.0, 1.0, 10),
(-2, 4, 30),
(5.234324, 8.98324, 324),
(2, 10, 100)
]
exception_configs = [
(0, 10, -1),
(0, 1, 2.5)
]
dtypes = ['int32', 'float16', 'float32', 'float64', None]
for config in configs:
for dtype in dtypes:
for endpoint in [False, True]:
for retstep in [False, True]:
if isinstance(config, tuple):
mx_ret = np.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype)
np_ret = _np.linspace(*config, endpoint=endpoint, retstep=retstep, dtype=dtype)
else:
mx_ret = np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
np_ret = _np.linspace(config, endpoint=endpoint, retstep=retstep, dtype=dtype)
if retstep:
assert_almost_equal(mx_ret[0].asnumpy(), np_ret[0], atol=1e-3, rtol=1e-5)
same(mx_ret[1], np_ret[1])
else:
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5)
# check for exception input
for config in exception_configs:
assertRaises(MXNetError, np.linspace, *config)
# check linspace equivalent to arange
for test_index in range(1000):
assert_almost_equal(mx.np.linspace(0, test_index, test_index + 1).asnumpy(), mx.np.arange(test_index + 1).asnumpy())
@npx.use_np
class TestLinspace(HybridBlock):
def __init__(self, start, stop, num=50, endpoint=None, retstep=False, dtype=None, axis=0):
super(TestLinspace, self).__init__()
self._start = start
self._stop = stop
self._num = num
self._endpoint = endpoint
self._retstep = retstep
self._dtype = dtype

def hybrid_forward(self, F, x):
if self._retstep:
raise ValueError("linspace didn't support retstep = True inside HybridBlock")
else:
return x + F.np.linspace(self._start, self._stop, self._num, \
self._endpoint, self._retstep, self._dtype)

for dtype in dtypes:
x = np.zeros(shape=(), dtype=dtype)
for config in configs:
for hybridize in [False, True]:
for endpoint in [False, True]:
if isinstance(config, tuple):
net = TestLinspace(*config, endpoint=endpoint, dtype=dtype)
np_out = _np.linspace(*config, endpoint=endpoint, dtype=dtype)
else:
net = TestLinspace(config, endpoint=endpoint, dtype=dtype)
np_out = _np.linspace(config, endpoint=endpoint, dtype=dtype)
if hybridize:
net.hybridize()
mx_out = net(x)
assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-3, rtol=1e-5)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: one more blank line here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


@with_seed()
@npx.use_np_shape
def test_np_argmax():
Expand Down