Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add numpy linspace #14927

Merged
merged 5 commits into from
May 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@
from ._internal import NDArrayBase

__all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
"ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
"imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
"maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
"power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
"split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack", "from_numpy"]
"ones", "add", "arange", "linspace", "eye", "divide", "equal", "full", "greater",
"greater_equal", "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or",
"logical_xor", "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal",
"onehot_encode", "power", "subtract", "true_divide", "waitall", "_new_empty_handle",
"histogram", "split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack",
"from_numpy"]

_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
Expand Down Expand Up @@ -2611,6 +2612,52 @@ def arange(start, stop=None, step=1.0, repeat=1, infer_range=None, ctx=None, dty
# pylint: enable= no-member, protected-access, too-many-arguments


# pylint: disable= no-member, protected-access, too-many-arguments
def linspace(start, stop, num, endpoint=True, ctx=None, dtype=mx_real_t):
"""Return evenly spaced numbers within a specified interval.

Values are generated within the half-open interval [`start`, `stop`) or
closed interval [start, stop] depending on whether `endpoint` is True or
False. The function is similar to `numpy.linspace`, but returns an `NDArray`.

Parameters
----------
start : number
Start of interval.
stop : number
End of interval, unless endpoint is set to False. In that case,
the sequence consists of all but the last of `num + 1` evenly spaced
samples, so that stop is excluded. Note that the step size changes
when endpoint is False.
num : number
Number of samples to generate. Must be non-negative.
endpoint : bool
If True, stop is the last sample. Otherwise, it is not included.
The default is True.
ctx : Context, optional
Device context. Default context is the current default context.
dtype : str or numpy.dtype, optional
The data type of the `NDArray`. The default datatype is `np.float32`.

Returns
-------
NDArray
`NDArray` of evenly spaced values in the specified range.

Examples
--------
>>> mx.nd.linspace(2.0, 3.0, 5).asnumpy()
array([ 2., 2.25., 2.5, 2.75, 3.], dtype=float32)
>>> mx.nd.linspace(2.0, 3.0, 5, endpoint=False).asnumpy()
array([ 2., 2.2., 2.4, 2.6, 2.8], dtype=float32)
"""
if ctx is None:
ctx = current_context()
return _internal._linspace(start=start, stop=stop, num=num,
endpoint=endpoint, dtype=dtype, ctx=str(ctx))
# pylint: disable= no-member, protected-access, too-many-arguments


#pylint: disable= too-many-arguments, no-member, protected-access
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None):
""" Helper function for element-wise operation.
Expand Down
40 changes: 38 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
from ._internal import SymbolBase, _set_symbol_class

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
"histogram", "split_v2"]
"pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
"ones", "full", "arange", "linspace", "histogram", "split_v2"]


class Symbol(SymbolBase):
Expand Down Expand Up @@ -3081,6 +3081,42 @@ def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, d
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
infer_range=infer_range, name=name, dtype=dtype)

def linspace(start, stop, num, endpoint=True, name=None, dtype=None):
"""Return evenly spaced numbers within a specified interval.

Values are generated within the half-open interval [`start`, `stop`) or
closed interval [start, stop] depending on whether `endpoint` is True or
False. The function is similar to `numpy.linspace`, but returns a `Symbol`.

Parameters
----------
start : number
Start of interval.
stop : number
End of interval, unless endpoint is set to False. In that case,
the sequence consists of all but the last of `num + 1` evenly spaced
samples, so that stop is excluded. Note that the step size changes
when endpoint is False.
num : number
Number of samples to generate. Must be non-negative.
endpoint : bool
If True, stop is the last sample. Otherwise, it is not included.
The default is True.
ctx : Context, optional
Device context. Default context is the current default context.
dtype : str or numpy.dtype, optional
The data type of the `NDArray`. The default datatype is `np.float32`.

Returns
-------
out : Symbol
The created Symbol
"""
if dtype is None:
dtype = _numpy.float32
return _internal._linspace(start=start, stop=stop, num=num, endpoint=endpoint,
name=name, dtype=dtype)

def histogram(a, bins=10, range=None, **kwargs):
"""Compute the histogram of the input data.

Expand Down
11 changes: 11 additions & 0 deletions src/operator/tensor/init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ DMLC_REGISTER_PARAMETER(InitOpWithScalarParam);
DMLC_REGISTER_PARAMETER(InitOpWithoutDTypeParam);
DMLC_REGISTER_PARAMETER(RangeParam);
DMLC_REGISTER_PARAMETER(EyeParam);
DMLC_REGISTER_PARAMETER(LinspaceParam);

NNVM_REGISTER_OP(_zeros_without_dtype)
.describe("fill target with zeros without default dtype")
Expand Down Expand Up @@ -99,6 +100,16 @@ NNVM_REGISTER_OP(_arange)
.set_attr<FCompute>("FCompute<cpu>", RangeCompute<cpu>)
.add_arguments(RangeParam::__FIELDS__());

NNVM_REGISTER_OP(_linspace)
.describe("Return evenly spaced numbers over a specified interval. Similar to Numpy")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LinspaceParam>)
.set_attr<mxnet::FInferShape>("FInferShape", LinspaceShape)
.set_attr<nnvm::FInferType>("FInferType", InitType<LinspaceParam>)
.set_attr<FCompute>("FCompute<cpu>", LinspaceCompute<cpu>)
.add_arguments(RangeParam::__FIELDS__());

NNVM_REGISTER_OP(zeros_like)
MXNET_ADD_SPARSE_OP_ALIAS(zeros_like)
.describe(R"code(Return an array of zeros with the same shape, type and storage type
Expand Down
3 changes: 3 additions & 0 deletions src/operator/tensor/init_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ NNVM_REGISTER_OP(_full)
NNVM_REGISTER_OP(_arange)
.set_attr<FCompute>("FCompute<gpu>", RangeCompute<gpu>);

NNVM_REGISTER_OP(_linspace)
.set_attr<FCompute>("FCompute<gpu>", LinspaceCompute<gpu>);

NNVM_REGISTER_OP(zeros_like)
.set_attr<FCompute>("FCompute<gpu>", FillCompute<gpu, 0>)
.set_attr<FComputeEx>("FComputeEx<gpu>", FillComputeZerosEx<gpu>);
Expand Down
69 changes: 69 additions & 0 deletions src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,33 @@ inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
attrs->parsed = std::move(param);
}

struct LinspaceParam : public dmlc::Parameter<LinspaceParam> {
double start;
double stop;
int num;
bool endpoint;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(LinspaceParam) {
DMLC_DECLARE_FIELD(start)
.describe("The starting value of the sequence.");
DMLC_DECLARE_FIELD(stop)
.describe("The ending value of the sequence");
DMLC_DECLARE_FIELD(num)
.describe("Number of samples to generate. Must be non-negative.");
DMLC_DECLARE_FIELD(endpoint)
.set_default(true)
.describe("If True, stop is the last sample. Otherwise, it is not included.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
MXNET_ADD_ALL_TYPES
.describe("Target data type.");
}
};

template<typename ParamType>
inline bool InitShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
Expand Down Expand Up @@ -519,6 +546,48 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
return true;
}

struct linspace_fwd {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, double start, double stop, double step,
int req, DType* out) {
KERNEL_ASSIGN(out[i], req, static_cast<DType>(start + step * i));
}
};

template<typename xpu>
void LinspaceCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
int step_num = param.endpoint ? param.num - 1 : param.num;
double step = step_num > 0 ? (param.stop - param.start) / step_num : 0.0f;
Kernel<linspace_fwd, xpu>::Launch(s,
outputs[0].Size(),
param.start,
param.stop,
step,
req[0],
outputs[0].dptr<DType>());
});
}

inline bool LinspaceShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
const LinspaceParam& param = nnvm::get<LinspaceParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 0U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK_GE(param.num, 0)
<< "Number of sequence should be non-negative, received " << param.num;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(param.num)}));
return true;
}

} // namespace op
} // namespace mxnet

Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,23 @@ def test_arange():
assert_almost_equal(pred, gt)


@with_seed()
def test_linspace():
for i in range(5):
start = np.random.rand() * 100
stop = np.random.rand() * 100
num = np.random.randint(20)
gt = np.linspace(start, stop, num)
pred = mx.nd.linspace(start, stop, num).asnumpy()
assert_almost_equal(pred, gt)
gt = np.linspace(start, stop, num, endpoint=False)
pred = mx.nd.linspace(start, stop, num, endpoint=False).asnumpy()
assert_almost_equal(pred, gt)
gt = np.linspace(start, stop, num, dtype="int32")
pred = mx.nd.linspace(start, stop, num, dtype="int32").asnumpy()
assert_almost_equal(pred, gt)


@with_seed()
def test_order():
ctx = default_context()
Expand Down