From 56bb7fbe11001d73fab3d74155622364f76d899d Mon Sep 17 00:00:00 2001 From: Taliesin Beynon Date: Tue, 7 Aug 2018 17:05:45 +0200 Subject: [PATCH] Allow stop of arange to be inferred from dims. Enabled via a flag. --- .../src/org/apache/clojure_mxnet/ndarray.clj | 12 ++++++++- .../src/org/apache/clojure_mxnet/symbol.clj | 12 ++++++++- python/mxnet/ndarray/ndarray.py | 6 ++--- python/mxnet/symbol/symbol.py | 4 +-- .../main/scala/org/apache/mxnet/NDArray.scala | 27 ++++++++++++++++--- .../main/scala/org/apache/mxnet/Symbol.scala | 27 ++++++++++++++++--- src/operator/tensor/init_op.h | 10 ++++++- tests/python/unittest/test_operator.py | 8 ++++++ 8 files changed, 90 insertions(+), 16 deletions(-) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj index e37a8bc8c98d..e8b5e4d84952 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj @@ -86,7 +86,17 @@ ([start stop {:keys [step repeat ctx dtype] :or {step (float 1) repeat (int 1) ctx (mx-context/default-context) dtype base/MX_REAL_TYPE} :as opts}] - (NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype)) + (NDArray/arange (float start) ($/option (float stop)) step repeat false ctx dtype)) + ([start stop] + (arange start stop {}))) + +(defn arange-with-inference + "Behaves like arange operator, but infers the stop value from the output shape, + which must be known from the rest of the net." + ([start stop {:keys [step repeat ctx dtype] + :or {step (float 1) repeat (int 1) ctx (mx-context/default-context) dtype base/MX_REAL_TYPE} + :as opts}] + (NDArray/arange (float start) ($/option (float stop)) step repeat true ctx dtype)) ([start stop] (arange start stop {}))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj index 42ae034eb6d3..aad741878a23 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj @@ -135,7 +135,17 @@ ([start stop {:keys [step repeat dtype] :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} :as opts}] - (Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype)) + (Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype)) + ([start stop] + (arange start stop {}))) + +(defn arange-with-inference + "Behaves like arange operator, but infers the stop value from the output shape, + which must be known from the rest of the net." + ([start stop {:keys [step repeat dtype] + :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} + :as opts}] + (Symbol/arange (float start) ($/option (float stop)) step repeat true nil dtype)) ([start stop] (arange start stop {}))) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 46b21a90d4c6..10c64495fc0b 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2475,7 +2475,7 @@ def moveaxis(tensor, source, destination): # pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name -def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): +def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t): """Returns evenly spaced values within a given interval. Values are generated within the half-open interval [`start`, `stop`). In other @@ -2518,8 +2518,8 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): """ if ctx is None: ctx = current_context() - return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, - dtype=dtype, ctx=str(ctx)) + return _internal._arange(start=start, stop=stop, step=step, infer_range=infer_range, + repeat=repeat, dtype=dtype, ctx=str(ctx)) # pylint: enable= no-member, protected-access, too-many-arguments diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 5f6cbd6b6e14..da5533f36668 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -2886,7 +2886,7 @@ def full(shape, val, dtype=None, **kwargs): return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs) # pylint: disable=redefined-outer-name -def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None): +def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None): """Returns evenly spaced values within a given interval. Parameters @@ -2911,7 +2911,7 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None): if dtype is None: dtype = _numpy.float32 return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, - name=name, dtype=dtype) + infer_range=infer_range, name=name, dtype=dtype) def histogram(a, bins=10, range=None, **kwargs): """Compute the histogram of the input data. diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 548c30b73a14..dded62c6e206 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -407,11 +407,30 @@ object NDArray extends NDArrayBase { * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. * @return NDArray of evenly spaced values in the specified range. */ - def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, - repeat: Int = 1, ctx: Context = Context.defaultCtx, + def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1, + ctx: Context = Context.defaultCtx, dType: DType = Base.MX_REAL_TYPE): NDArray = { + val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, + "infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString()) + val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) + NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0) + } + + /** + * Behaves like arange operator, but infers the stop value from the output shape, + * which must be known from the rest of the net. + * @param start Start of interval. The default start value is 0. + * @param stop End of interval. + * @param step Spacing between values. The default step size is 1. + * @param repeat Number of times to repeat each element. The default repeat count is 1. + * @param ctx Device context. Default context is the current default context. + * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. + * @return NDArray of evenly spaced values in the specified range. + */ + def arangeWithInference(start: Float, stop: Option[Float] = None, step: Float = 1.0f, + repeat: Int = 1, ctx: Context = Context.defaultCtx, dType: DType = Base.MX_REAL_TYPE): NDArray = { - val params = Map("start" -> start, "step" -> step, - "repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString()) + val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, + "infer_range" -> true, "ctx" -> ctx.toString, "dtype" -> dType.toString()) val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 194d3681523f..880dcd626465 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -954,10 +954,29 @@ object Symbol extends SymbolBase { * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. * @return Symbol The created Symbol. */ - def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, - repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = { - val params = Map("start" -> start, "step" -> step, - "repeat" -> repeat, "dtype" -> dType.toString()) + def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1, + name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = { + val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, + "infer_range" -> false, "dtype" -> dType.toString()) + val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) + createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams) + } + +/** + * Behaves like arange operator, but infers the stop value from the output shape, + * which must be known from the rest of the net. + * @param start Start of interval. The default start value is 0. + * @param stop End of interval. + * @param step Spacing between values. The default step size is 1. + * @param repeat Number of times to repeat each element. The default repeat count is 1. + * @param ctx Device context. Default context is the current default context. + * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. + * @return NDArray of evenly spaced values in the specified range. + */ + def arangeWithInference(start: Float, stop: Option[Float] = None, step: Float = 1.0f, + repeat: Int = 1, dType: DType = Base.MX_REAL_TYPE): Symbol = { + val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, + "infer_range" -> true, "dtype" -> dType.toString()) val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams) } diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 4af3a40f42ab..304911a02a78 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -123,6 +123,7 @@ struct RangeParam : public dmlc::Parameter { dmlc::optional stop; double step; int repeat; + bool infer_range; std::string ctx; int dtype; DMLC_DECLARE_PARAMETER(RangeParam) { @@ -140,6 +141,10 @@ struct RangeParam : public dmlc::Parameter { .set_default(1) .describe("The repeating time of all elements." " E.g repeat=3, the element a will be repeated three times --> a, a, a."); + DMLC_DECLARE_FIELD(infer_range) + .set_default(false) + .describe("Whether to infer the stop position from the start, step, repeat, and output tensor" + "size."); DMLC_DECLARE_FIELD(ctx) .set_default("") .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." @@ -176,7 +181,7 @@ struct InitOpWithScalarParam : dmlc::Parameter { inline void RangeParamParser(nnvm::NodeAttrs* attrs) { RangeParam param; param.Init(attrs->dict); - if (!static_cast(param.stop)) { + if (!static_cast(param.infer_range) && !static_cast(param.stop)) { param.stop = param.start; param.start = 0; } @@ -471,6 +476,9 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs, << "Range does not support step=0, received " << param.step; CHECK(param.repeat > 0) << "Range only supports repeat > 0, received " << param.repeat; + if (param.infer_range && !param.stop.has_value()) { + return false; + } if (param.step > 0) { CHECK(param.start < param.stop.value()) << "Invalid range (start, stop, step) = " diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 90e85d123d59..e0f7219ea762 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3555,10 +3555,18 @@ def test_arange(): nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype) assert_almost_equal(np_out, nd_out.asnumpy()) + def test_arange_inferstop(): + s = mx.sym.arange(start=0, stop=None, infer_range=True) + s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5])) + exe = s.bind(ctx=mx.cpu(), args={}) + exe.forward() + assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4])) + test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32) test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32) test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16) test_arange() + test_arange_inferstop() @with_seed()