diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index a3b4a2730378..c92f359f5997 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -35,7 +35,7 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique'] + 'unique', 'logspace'] @set_module('mxnet.ndarray.numpy') @@ -3263,3 +3263,86 @@ def hypot(x1, x2, out=None): [ 5., 5., 5.]]) """ return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out) + + +@set_module('mxnet.ndarray.numpy') +def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments + r"""Return numbers spaced evenly on a log scale. + + In linear space, the sequence starts at ``base ** start`` + (`base` to the power of `start`) and ends with ``base ** stop`` + (see `endpoint` below). + + Non-scalar `start` and `stop` are now supported. + + Parameters + ---------- + start : int or float + ``base ** start`` is the starting value of the sequence. + stop : int or float + ``base ** stop`` is the final value of the sequence, unless `endpoint` + is False. In that case, ``num + 1`` values are spaced over the + interval in log-space, of which all but the last (a sequence of + length `num`) are returned. + num : integer, optional + Number of samples to generate. Default is 50. + endpoint : boolean, optional + If true, `stop` is the last sample. Otherwise, it is not included. + Default is True. + base : float, optional + The base of the log space. The step size between the elements in + ``ln(samples) / ln(base)`` (or ``log_base(samples)``) is uniform. + Default is 10.0. + dtype : dtype + The type of the output array. If `dtype` is not given, infer the data + type from the other input arguments. + axis : int, optional + The axis in the result to store the samples. Relevant only if start + or stop are array-like. By default (0), the samples will be along a + new axis inserted at the beginning. Now, axis only support axis = 0. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + samples : ndarray + `num` samples, equally spaced on a log scale. + + See Also + -------- + arange : Similar to linspace, with the step size specified instead of the + number of samples. Note that, when used with a float endpoint, the + endpoint may or may not be included. + linspace : Similar to logspace, but with the samples uniformly distributed + in linear space, instead of log space. + + Notes + ----- + Logspace is equivalent to the code. Now wo only support axis = 0. + + >>> y = np.linspace(start, stop, num=num, endpoint=endpoint) + ... + >>> power(base, y).astype(dtype) + ... + + Examples + -------- + >>> np.logspace(2.0, 3.0, num=4) + array([ 100. , 215.44347, 464.15887, 1000. ]) + >>> np.logspace(2.0, 3.0, num=4, endpoint=False) + array([100. , 177.82794, 316.22775, 562.3413 ]) + >>> np.logspace(2.0, 3.0, num=4, base=2.0) + array([4. , 5.0396843, 6.349604 , 8. ]) + >>> np.logspace(2.0, 3.0, num=4, base=2.0, dtype=np.int32) + array([4, 5, 6, 8], dtype=int32) + >>> np.logspace(2.0, 3.0, num=4, ctx=npx.gpu(0)) + array([ 100. , 215.44347, 464.15887, 1000. ], ctx=gpu(0)) + """ + if isinstance(start, (list, tuple, _np.ndarray, NDArray)) or \ + isinstance(stop, (list, tuple, _np.ndarray, NDArray)): + raise NotImplementedError('start and stop only support int and float') + if axis != 0: + raise NotImplementedError("the function only support axis 0") + if ctx is None: + ctx = current_context() + return _npi.logspace(start=start, stop=stop, num=num, endpoint=endpoint, base=base, ctx=ctx, dtype=dtype) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 4972bdae7df6..9eb808bc693f 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -54,7 +54,7 @@ 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', - 'rad2deg', 'deg2rad', 'unique'] + 'rad2deg', 'deg2rad', 'unique', 'logspace'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -4792,3 +4792,79 @@ def hypot(x1, x2, out=None): [ 5., 5., 5.]]) """ return _mx_nd_np.hypot(x1, x2, out=out) + + +@set_module('mxnet.numpy') +def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0, ctx=None): + r"""Return numbers spaced evenly on a log scale. + + In linear space, the sequence starts at ``base ** start`` + (`base` to the power of `start`) and ends with ``base ** stop`` + (see `endpoint` below). + + Non-scalar `start` and `stop` are now supported. + + Parameters + ---------- + start : int or float + ``base ** start`` is the starting value of the sequence. + stop : int or float + ``base ** stop`` is the final value of the sequence, unless `endpoint` + is False. In that case, ``num + 1`` values are spaced over the + interval in log-space, of which all but the last (a sequence of + length `num`) are returned. + num : integer, optional + Number of samples to generate. Default is 50. + endpoint : boolean, optional + If true, `stop` is the last sample. Otherwise, it is not included. + Default is True. + base : float, optional + The base of the log space. The step size between the elements in + ``ln(samples) / ln(base)`` (or ``log_base(samples)``) is uniform. + Default is 10.0. + dtype : dtype + The type of the output array. If `dtype` is not given, infer the data + type from the other input arguments. + axis : int, optional + The axis in the result to store the samples. Relevant only if start + or stop are array-like. By default (0), the samples will be along a + new axis inserted at the beginning. Now, axis only support axis = 0. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + samples : ndarray + `num` samples, equally spaced on a log scale. + + See Also + -------- + arange : Similar to linspace, with the step size specified instead of the + number of samples. Note that, when used with a float endpoint, the + endpoint may or may not be included. + linspace : Similar to logspace, but with the samples uniformly distributed + in linear space, instead of log space. + + Notes + ----- + Logspace is equivalent to the code + + >>> y = np.linspace(start, stop, num=num, endpoint=endpoint) + ... + >>> power(base, y).astype(dtype) + ... + + Examples + -------- + >>> np.logspace(2.0, 3.0, num=4) + array([ 100. , 215.44347, 464.15887, 1000. ]) + >>> np.logspace(2.0, 3.0, num=4, endpoint=False) + array([100. , 177.82794, 316.22775, 562.3413 ]) + >>> np.logspace(2.0, 3.0, num=4, base=2.0) + array([4. , 5.0396843, 6.349604 , 8. ]) + >>> np.logspace(2.0, 3.0, num=4, base=2.0, dtype=np.int32) + array([4, 5, 6, 8], dtype=int32) + >>> np.logspace(2.0, 3.0, num=4, ctx=npx.gpu(0)) + array([ 100. , 215.44347, 464.15887, 1000. ], ctx=gpu(0)) + """ + return _mx_nd_np.logspace(start, stop, num, endpoint, base, dtype, axis, ctx=ctx) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 57b18ecaf547..233cdc6900f4 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -37,7 +37,7 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique'] + 'unique', 'logspace'] def _num_outputs(sym): @@ -3394,4 +3394,87 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, ax return _npi.unique(ar, return_index, return_inverse, return_counts, axis) +@set_module('mxnet.symbol.numpy') +def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments + r"""Return numbers spaced evenly on a log scale. + + In linear space, the sequence starts at ``base ** start`` + (`base` to the power of `start`) and ends with ``base ** stop`` + (see `endpoint` below). + + Non-scalar `start` and `stop` are now supported. + + Parameters + ---------- + start : scalar + ``base ** start`` is the starting value of the sequence. + stop : scalar + ``base ** stop`` is the final value of the sequence, unless `endpoint` + is False. In that case, ``num + 1`` values are spaced over the + interval in log-space, of which all but the last (a sequence of + length `num`) are returned. + num : scalar, optional + Number of samples to generate. Default is 50. + endpoint : boolean, optional + If true, `stop` is the last sample. Otherwise, it is not included. + Default is True. + base : scalar, optional + The base of the log space. The step size between the elements in + ``ln(samples) / ln(base)`` (or ``log_base(samples)``) is uniform. + Default is 10.0. + dtype : dtype + The type of the output array. If `dtype` is not given, infer the data + type from the other input arguments. + axis : scalar, optional + The axis in the result to store the samples. Relevant only if start + or stop are array-like. By default (0), the samples will be along a + new axis inserted at the beginning. Now, axis only support axis = 0. + ctx : Context, optional + An optional device context (default is the current default context). + + Returns + ------- + samples : _Symbol + `num` samples, equally spaced on a log scale. + + See Also + -------- + arange : Similar to linspace, with the step size specified instead of the + number of samples. Note that, when used with a float endpoint, the + endpoint may or may not be included. + linspace : Similar to logspace, but with the samples uniformly distributed + in linear space, instead of log space. + + Notes + ----- + Logspace is equivalent to the code + + >>> y = np.linspace(start, stop, num=num, endpoint=endpoint) + ... + >>> power(base, y).astype(dtype) + ... + + Examples + -------- + >>> np.logspace(2.0, 3.0, num=4) + array([ 100. , 215.44347, 464.15887, 1000. ]) + >>> np.logspace(2.0, 3.0, num=4, endpoint=False) + array([100. , 177.82794, 316.22775, 562.3413 ]) + >>> np.logspace(2.0, 3.0, num=4, base=2.0) + array([4. , 5.0396843, 6.349604 , 8. ]) + >>> np.logspace(2.0, 3.0, num=4, base=2.0, dtype=np.int32) + array([4, 5, 6, 8], dtype=int32) + >>> np.logspace(2.0, 3.0, num=4, ctx=npx.gpu(0)) + array([ 100. , 215.44347, 464.15887, 1000. ], ctx=gpu(0)) + """ + if isinstance(start, (list, _np.ndarray)) or \ + isinstance(stop, (list, _np.ndarray)): + raise NotImplementedError('start and stop only support int') + if axis != 0: + raise NotImplementedError("the function only support axis 0") + if ctx is None: + ctx = current_context() + return _npi.logspace(start=start, stop=stop, num=num, endpoint=endpoint, base=base, ctx=ctx, dtype=dtype) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc index 4f031bdaa050..8e01b8e8c8ea 100644 --- a/src/operator/numpy/np_init_op.cc +++ b/src/operator/numpy/np_init_op.cc @@ -30,6 +30,7 @@ namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(IndicesOpParam); +DMLC_REGISTER_PARAMETER(LogspaceParam); inline bool NumpyIndicesShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shapes, @@ -51,6 +52,18 @@ inline bool NumpyIndicesShape(const nnvm::NodeAttrs& attrs, return shape_is_known(out_shapes->at(0)); } +inline bool LogspaceShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + const LogspaceParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 0U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_GE(param.num, 0) + << "Number of sequence should be non-negative, received " << param.num; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast(param.num)})); + return true; +} + NNVM_REGISTER_OP(_npi_zeros) .set_num_inputs(0) .set_num_outputs(1) @@ -143,5 +156,15 @@ NNVM_REGISTER_OP(_npi_indices) .set_attr("FCompute", IndicesCompute) .add_arguments(IndicesOpParam::__FIELDS__()); +NNVM_REGISTER_OP(_npi_logspace) +.describe("Return numbers spaced evenly on a log scale.") +.set_num_inputs(0) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", LogspaceShape) +.set_attr("FInferType", InitType) +.set_attr("FCompute", LogspaceCompute) +.add_arguments(LogspaceParam::__FIELDS__()); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu index 49f1051735d8..d0291676446c 100644 --- a/src/operator/numpy/np_init_op.cu +++ b/src/operator/numpy/np_init_op.cu @@ -47,5 +47,8 @@ NNVM_REGISTER_OP(_npi_arange) NNVM_REGISTER_OP(_npi_indices) .set_attr("FCompute", IndicesCompute); +NNVM_REGISTER_OP(_npi_logspace) +.set_attr("FCompute", LogspaceCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h index 5c41820b57f8..913bfc3951dc 100644 --- a/src/operator/numpy/np_init_op.h +++ b/src/operator/numpy/np_init_op.h @@ -25,6 +25,7 @@ #ifndef MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ #define MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ +#include #include #include #include "../tensor/init_op.h" @@ -101,6 +102,65 @@ void IndicesCompute(const nnvm::NodeAttrs& attrs, } } +struct LogspaceParam : public dmlc::Parameter { + double start; + double stop; + int num; + bool endpoint; + double base; + std::string ctx; + int dtype; + DMLC_DECLARE_PARAMETER(LogspaceParam) { + DMLC_DECLARE_FIELD(start) + .describe("The starting value of the sequence."); + DMLC_DECLARE_FIELD(stop) + .describe("The ending value of the sequence"); + DMLC_DECLARE_FIELD(num) + .describe("Number of samples to generate. Must be non-negative."); + DMLC_DECLARE_FIELD(endpoint) + .set_default(true) + .describe("If True, stop is the last sample. Otherwise, it is not included."); + DMLC_DECLARE_FIELD(base) + .set_default(10.0) + .describe("The base of the log space. The step size between the elements in " + "ln(samples) / ln(base) (or log_base(samples)) is uniform."); + DMLC_DECLARE_FIELD(ctx) + .set_default("") + .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." + "Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32) + MXNET_ADD_ALL_TYPES + .describe("Target data type."); + } +}; + +struct logspace_fwd { + template + MSHADOW_XINLINE static void Map(index_t i, double start, double stop, double base, + double step, int req, DType* out) { + KERNEL_ASSIGN(out[i], req, + static_cast(math::pow(base, static_cast(start + step * i)))); + } +}; + +template +void LogspaceCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + const LogspaceParam& param = nnvm::get(attrs.parsed); + if (param.num == 0) return; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + int step_num = param.endpoint ? param.num - 1 : param.num; + double step = step_num > 0 ? (param.stop - param.start) / step_num : 0.0f; + Kernel::Launch(s, outputs[0].Size(), param.start, param.stop, param.base, + step, req[0], outputs[0].dptr()); + }); +} + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index d2dc6ab269dc..67bba8d08ee1 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2600,6 +2600,58 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_logspace(): + class TestLogspace(HybridBlock): + def __init__(self, start, stop, num=50, endpoint=None, base=50.0, dtype=None, axis=0): + super(TestLogspace, self).__init__() + self._start = start + self._stop = stop + self._num = num + self._endpoint = endpoint + self._base = base + self._dtype = dtype + self.axis = axis + + def hybrid_forward(self, F, x): + return x + F.np.logspace(self._start, self._stop, self._num, self._endpoint, self._base, self._dtype, self.axis) + + configs = [ + (0.0, 1.0, 20), + (2, 8, 0), + (22, 11, 1), + (2.22, 9.99, 11), + (4.99999, 12.11111111, 111) + ] + base_configs = [0, 1, 5, 8, 10, 33] + dtypes = ['float32', 'float64', None] + + for config in configs: + for dtype in dtypes: + for endpoint in [False, True]: + for hybridize in [False, True]: + for base in base_configs: + x = np.zeros(shape=(), dtype=dtype) + net = TestLogspace(*config, endpoint=endpoint, base=base, dtype=dtype) + np_out = _np.logspace(*config, endpoint=endpoint, base=base, dtype=dtype) + if hybridize: + net.hybridize() + mx_out = net(x) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-3, rtol=1e-5) + if dtype is not None: + assert mx_out.dtype == np_out.dtype + + # Test imperative once again + mx_ret = np.logspace(*config, endpoint=endpoint, base=base, dtype=dtype) + np_ret = _np.logspace(*config, endpoint=endpoint, base=base, dtype=dtype) + + assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-3, rtol=1e-5) + assert mx_ret.dtype == np_ret.dtype + if dtype is not None: + assert mx_out.dtype == np_out.dtype + + if __name__ == '__main__': import nose nose.runmodule()