Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Allow stop of arange to be inferred from dims.
Browse files Browse the repository at this point in the history
Enabled via a flag.
  • Loading branch information
Taliesin Beynon committed Aug 18, 2018
1 parent 9dd5edd commit 56bb7fb
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 16 deletions.
12 changes: 11 additions & 1 deletion contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,17 @@
([start stop {:keys [step repeat ctx dtype]
:or {step (float 1) repeat (int 1) ctx (mx-context/default-context) dtype base/MX_REAL_TYPE}
:as opts}]
(NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
(NDArray/arange (float start) ($/option (float stop)) step repeat false ctx dtype))
([start stop]
(arange start stop {})))

(defn arange-with-inference
"Behaves like arange operator, but infers the stop value from the output shape,
which must be known from the rest of the net."
([start stop {:keys [step repeat ctx dtype]
:or {step (float 1) repeat (int 1) ctx (mx-context/default-context) dtype base/MX_REAL_TYPE}
:as opts}]
(NDArray/arange (float start) ($/option (float stop)) step repeat true ctx dtype))
([start stop]
(arange start stop {})))

Expand Down
12 changes: 11 additions & 1 deletion contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,17 @@
([start stop {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
(Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype))
(Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype))
([start stop]
(arange start stop {})))

(defn arange-with-inference
"Behaves like arange operator, but infers the stop value from the output shape,
which must be known from the rest of the net."
([start stop {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
(Symbol/arange (float start) ($/option (float stop)) step repeat true nil dtype))
([start stop]
(arange start stop {})))

Expand Down
6 changes: 3 additions & 3 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2475,7 +2475,7 @@ def moveaxis(tensor, source, destination):


# pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name
def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t):
"""Returns evenly spaced values within a given interval.
Values are generated within the half-open interval [`start`, `stop`). In other
Expand Down Expand Up @@ -2518,8 +2518,8 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
"""
if ctx is None:
ctx = current_context()
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
dtype=dtype, ctx=str(ctx))
return _internal._arange(start=start, stop=stop, step=step, infer_range=infer_range,
repeat=repeat, dtype=dtype, ctx=str(ctx))
# pylint: enable= no-member, protected-access, too-many-arguments


Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2886,7 +2886,7 @@ def full(shape, val, dtype=None, **kwargs):
return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs)

# pylint: disable=redefined-outer-name
def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):
"""Returns evenly spaced values within a given interval.
Parameters
Expand All @@ -2911,7 +2911,7 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
if dtype is None:
dtype = _numpy.float32
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
name=name, dtype=dtype)
infer_range=infer_range, name=name, dtype=dtype)

def histogram(a, bins=10, range=None, **kwargs):
"""Compute the histogram of the input data.
Expand Down
27 changes: 23 additions & 4 deletions scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,30 @@ object NDArray extends NDArrayBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, ctx: Context = Context.defaultCtx,
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1,
ctx: Context = Context.defaultCtx, dType: DType = Base.MX_REAL_TYPE): NDArray = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
}

/**
* Behaves like arange operator, but infers the stop value from the output shape,
* which must be known from the rest of the net.
* @param start Start of interval. The default start value is 0.
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arangeWithInference(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, ctx: Context = Context.defaultCtx,
dType: DType = Base.MX_REAL_TYPE): NDArray = {
val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> true, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
}
Expand Down
27 changes: 23 additions & 4 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -954,10 +954,29 @@ object Symbol extends SymbolBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return Symbol The created Symbol.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "dtype" -> dType.toString())
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1,
name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> false, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
}

/**
* Behaves like arange operator, but infers the stop value from the output shape,
* which must be known from the rest of the net.
* @param start Start of interval. The default start value is 0.
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arangeWithInference(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, dType: DType = Base.MX_REAL_TYPE): Symbol = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> true, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
}
Expand Down
10 changes: 9 additions & 1 deletion src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
dmlc::optional<double> stop;
double step;
int repeat;
bool infer_range;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(RangeParam) {
Expand All @@ -140,6 +141,10 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
.set_default(1)
.describe("The repeating time of all elements."
" E.g repeat=3, the element a will be repeated three times --> a, a, a.");
DMLC_DECLARE_FIELD(infer_range)
.set_default(false)
.describe("Whether to infer the stop position from the start, step, repeat, and output tensor"
"size.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
Expand Down Expand Up @@ -176,7 +181,7 @@ struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
RangeParam param;
param.Init(attrs->dict);
if (!static_cast<bool>(param.stop)) {
if (!static_cast<bool>(param.infer_range) && !static_cast<bool>(param.stop)) {
param.stop = param.start;
param.start = 0;
}
Expand Down Expand Up @@ -471,6 +476,9 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
<< "Range does not support step=0, received " << param.step;
CHECK(param.repeat > 0)
<< "Range only supports repeat > 0, received " << param.repeat;
if (param.infer_range && !param.stop.has_value()) {
return false;
}
if (param.step > 0) {
CHECK(param.start < param.stop.value())
<< "Invalid range (start, stop, step) = "
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3555,10 +3555,18 @@ def test_arange():
nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype)
assert_almost_equal(np_out, nd_out.asnumpy())

def test_arange_inferstop():
s = mx.sym.arange(start=0, stop=None, infer_range=True)
s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5]))
exe = s.bind(ctx=mx.cpu(), args={})
exe.forward()
assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4]))

test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32)
test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32)
test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16)
test_arange()
test_arange_inferstop()


@with_seed()
Expand Down

0 comments on commit 56bb7fb

Please sign in to comment.