diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index d912d38930a5..1c182731c78e 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -43,11 +43,12 @@ from ._internal import NDArrayBase __all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP", - "ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal", - "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor", - "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode", - "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram", - "split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack", "from_numpy"] + "ones", "add", "arange", "linspace", "eye", "divide", "equal", "full", "greater", + "greater_equal", "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", + "logical_xor", "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", + "onehot_encode", "power", "subtract", "true_divide", "waitall", "_new_empty_handle", + "histogram", "split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack", + "from_numpy"] _STORAGE_TYPE_UNDEFINED = -1 _STORAGE_TYPE_DEFAULT = 0 @@ -2611,6 +2612,52 @@ def arange(start, stop=None, step=1.0, repeat=1, infer_range=None, ctx=None, dty # pylint: enable= no-member, protected-access, too-many-arguments +# pylint: disable= no-member, protected-access, too-many-arguments +def linspace(start, stop, num, endpoint=True, ctx=None, dtype=mx_real_t): + """Return evenly spaced numbers within a specified interval. + + Values are generated within the half-open interval [`start`, `stop`) or + closed interval [start, stop] depending on whether `endpoint` is True or + False. The function is similar to `numpy.linspace`, but returns an `NDArray`. + + Parameters + ---------- + start : number + Start of interval. + stop : number + End of interval, unless endpoint is set to False. In that case, + the sequence consists of all but the last of `num + 1` evenly spaced + samples, so that stop is excluded. Note that the step size changes + when endpoint is False. + num : number + Number of samples to generate. Must be non-negative. + endpoint : bool + If True, stop is the last sample. Otherwise, it is not included. + The default is True. + ctx : Context, optional + Device context. Default context is the current default context. + dtype : str or numpy.dtype, optional + The data type of the `NDArray`. The default datatype is `np.float32`. + + Returns + ------- + NDArray + `NDArray` of evenly spaced values in the specified range. + + Examples + -------- + >>> mx.nd.linspace(2.0, 3.0, 5).asnumpy() + array([ 2., 2.25., 2.5, 2.75, 3.], dtype=float32) + >>> mx.nd.linspace(2.0, 3.0, 5, endpoint=False).asnumpy() + array([ 2., 2.2., 2.4, 2.6, 2.8], dtype=float32) + """ + if ctx is None: + ctx = current_context() + return _internal._linspace(start=start, stop=stop, num=num, + endpoint=endpoint, dtype=dtype, ctx=str(ctx)) +# pylint: disable= no-member, protected-access, too-many-arguments + + #pylint: disable= too-many-arguments, no-member, protected-access def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None): """ Helper function for element-wise operation. diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 467d612700ec..7c800dfd0c88 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -47,8 +47,8 @@ from ._internal import SymbolBase, _set_symbol_class __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json", - "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange", - "histogram", "split_v2"] + "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", + "ones", "full", "arange", "linspace", "histogram", "split_v2"] class Symbol(SymbolBase): @@ -3081,6 +3081,42 @@ def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, d return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, infer_range=infer_range, name=name, dtype=dtype) +def linspace(start, stop, num, endpoint=True, name=None, dtype=None): + """Return evenly spaced numbers within a specified interval. + + Values are generated within the half-open interval [`start`, `stop`) or + closed interval [start, stop] depending on whether `endpoint` is True or + False. The function is similar to `numpy.linspace`, but returns a `Symbol`. + + Parameters + ---------- + start : number + Start of interval. + stop : number + End of interval, unless endpoint is set to False. In that case, + the sequence consists of all but the last of `num + 1` evenly spaced + samples, so that stop is excluded. Note that the step size changes + when endpoint is False. + num : number + Number of samples to generate. Must be non-negative. + endpoint : bool + If True, stop is the last sample. Otherwise, it is not included. + The default is True. + ctx : Context, optional + Device context. Default context is the current default context. + dtype : str or numpy.dtype, optional + The data type of the `NDArray`. The default datatype is `np.float32`. + + Returns + ------- + out : Symbol + The created Symbol + """ + if dtype is None: + dtype = _numpy.float32 + return _internal._linspace(start=start, stop=stop, num=num, endpoint=endpoint, + name=name, dtype=dtype) + def histogram(a, bins=10, range=None, **kwargs): """Compute the histogram of the input data. diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index 341748b50abe..a58498b85aa5 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -33,6 +33,7 @@ DMLC_REGISTER_PARAMETER(InitOpWithScalarParam); DMLC_REGISTER_PARAMETER(InitOpWithoutDTypeParam); DMLC_REGISTER_PARAMETER(RangeParam); DMLC_REGISTER_PARAMETER(EyeParam); +DMLC_REGISTER_PARAMETER(LinspaceParam); NNVM_REGISTER_OP(_zeros_without_dtype) .describe("fill target with zeros without default dtype") @@ -99,6 +100,16 @@ NNVM_REGISTER_OP(_arange) .set_attr("FCompute", RangeCompute) .add_arguments(RangeParam::__FIELDS__()); +NNVM_REGISTER_OP(_linspace) +.describe("Return evenly spaced numbers over a specified interval. Similar to Numpy") +.set_num_inputs(0) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", LinspaceShape) +.set_attr("FInferType", InitType) +.set_attr("FCompute", LinspaceCompute) +.add_arguments(RangeParam::__FIELDS__()); + NNVM_REGISTER_OP(zeros_like) MXNET_ADD_SPARSE_OP_ALIAS(zeros_like) .describe(R"code(Return an array of zeros with the same shape, type and storage type diff --git a/src/operator/tensor/init_op.cu b/src/operator/tensor/init_op.cu index 902b567516bd..5829ff2237e8 100644 --- a/src/operator/tensor/init_op.cu +++ b/src/operator/tensor/init_op.cu @@ -64,6 +64,9 @@ NNVM_REGISTER_OP(_full) NNVM_REGISTER_OP(_arange) .set_attr("FCompute", RangeCompute); +NNVM_REGISTER_OP(_linspace) +.set_attr("FCompute", LinspaceCompute); + NNVM_REGISTER_OP(zeros_like) .set_attr("FCompute", FillCompute) .set_attr("FComputeEx", FillComputeZerosEx); diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index b2e3830064ae..e4b090db933e 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -207,6 +207,33 @@ inline void RangeParamParser(nnvm::NodeAttrs* attrs) { attrs->parsed = std::move(param); } +struct LinspaceParam : public dmlc::Parameter { + double start; + double stop; + int num; + bool endpoint; + std::string ctx; + int dtype; + DMLC_DECLARE_PARAMETER(LinspaceParam) { + DMLC_DECLARE_FIELD(start) + .describe("The starting value of the sequence."); + DMLC_DECLARE_FIELD(stop) + .describe("The ending value of the sequence"); + DMLC_DECLARE_FIELD(num) + .describe("Number of samples to generate. Must be non-negative."); + DMLC_DECLARE_FIELD(endpoint) + .set_default(true) + .describe("If True, stop is the last sample. Otherwise, it is not included."); + DMLC_DECLARE_FIELD(ctx) + .set_default("") + .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." + "Only used for imperative calls."); + DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32) + MXNET_ADD_ALL_TYPES + .describe("Target data type."); + } +}; + template inline bool InitShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, @@ -519,6 +546,48 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs, return true; } +struct linspace_fwd { + template + MSHADOW_XINLINE static void Map(index_t i, double start, double stop, double step, + int req, DType* out) { + KERNEL_ASSIGN(out[i], req, static_cast(start + step * i)); + } +}; + +template +void LinspaceCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + const LinspaceParam& param = nnvm::get(attrs.parsed); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + int step_num = param.endpoint ? param.num - 1 : param.num; + double step = step_num > 0 ? (param.stop - param.start) / step_num : 0.0f; + Kernel::Launch(s, + outputs[0].Size(), + param.start, + param.stop, + step, + req[0], + outputs[0].dptr()); + }); +} + +inline bool LinspaceShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + const LinspaceParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 0U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_GE(param.num, 0) + << "Number of sequence should be non-negative, received " << param.num; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast(param.num)})); + return true; +} + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 374050668612..c62bd19453d9 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -726,6 +726,23 @@ def test_arange(): assert_almost_equal(pred, gt) +@with_seed() +def test_linspace(): + for i in range(5): + start = np.random.rand() * 100 + stop = np.random.rand() * 100 + num = np.random.randint(20) + gt = np.linspace(start, stop, num) + pred = mx.nd.linspace(start, stop, num).asnumpy() + assert_almost_equal(pred, gt) + gt = np.linspace(start, stop, num, endpoint=False) + pred = mx.nd.linspace(start, stop, num, endpoint=False).asnumpy() + assert_almost_equal(pred, gt) + gt = np.linspace(start, stop, num, dtype="int32") + pred = mx.nd.linspace(start, stop, num, dtype="int32").asnumpy() + assert_almost_equal(pred, gt) + + @with_seed() def test_order(): ctx = default_context()