diff --git a/contrib/clojure-package/README.md b/contrib/clojure-package/README.md index 5e7356caf647..ea678ccf2db4 100644 --- a/contrib/clojure-package/README.md +++ b/contrib/clojure-package/README.md @@ -107,7 +107,9 @@ The jars from maven with the needed MXNet native binaries in it. On startup, the ### Build from MXNET Source -Checkout the latest sha from the main package +First, ensure you have JDK 8 on your system. Later versions may produce cryptic build errors mentioning `scala.reflect.internal.MissingRequirementError`. + +Checkout the latest SHA from the main package: `git clone --recursive https://github.com/apache/incubator-mxnet.git ~/mxnet` `cd ~/mxnet` diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj index e37a8bc8c98d..7ca4ede9733c 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj @@ -89,7 +89,7 @@ (NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype)) ([start stop] (arange start stop {}))) - + (defn slice "Return a sliced NDArray that shares memory with current one." ([ndarray i] diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj index 42ae034eb6d3..12135fb75cab 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj @@ -135,10 +135,20 @@ ([start stop {:keys [step repeat dtype] :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} :as opts}] - (Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype)) + (Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype)) ([start stop] (arange start stop {}))) +(defn arange-with-inference + "Behaves like arange operator, but infers the stop value from the output shape, + which must be known from the rest of the net." + ([start {:keys [step repeat dtype] + :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} + :as opts}] + (Symbol/arange (float start) ($/option nil) step repeat true nil dtype)) + ([start] + (arange-with-inference start {}))) + ;;; manually defined because of a conflicting arity of 2 with the auto-gen (defn min ([sym-name kwargs-map symbol-list kwargs-map-1] diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj index a71a312e1ae6..1b4b2ea2fbe3 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj @@ -222,6 +222,17 @@ (is (= 0 (count (executor/grad-arrays exec)))) (is (approx= 1e-4 result (-> (executor/outputs exec) (first)))))) +(deftest test-arange-with-inference + (let [arange (sym/arange-with-inference 0) + data (sym/variable "data") + added (sym/+ arange data) + result (range 0 4) + data-tmp (ndarray/zeros [4]) + exec (sym/bind added (context/default-context) {"data" data-tmp})] + (executor/forward exec) + (is (= 0 (count (executor/grad-arrays exec)))) + (is (approx= 1e-4 result (-> (executor/outputs exec) (first)))))) + (deftest test-scalar-pow (let [data (sym/variable "data") shape-vec [1 1] diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj index dcdbea645796..ecd54ca72773 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj @@ -22,6 +22,8 @@ (if (and (number? x) (number? y)) (let [diff (Math/abs (- x y))] (< diff tolerance)) - (reduce (fn [x y] (and x y)) - (map #(approx= tolerance %1 %2) x y)))) + (and + (= (count x) (count y)) + (reduce (fn [x y] (and x y)) + (map #(approx= tolerance %1 %2) x y))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj index 5551fab435f6..de3480827ba4 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj @@ -21,6 +21,7 @@ [org.apache.clojure-mxnet.util :as util] [org.apache.clojure-mxnet.ndarray :as ndarray] [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.test-util :as test-util] [clojure.spec.alpha :as s]) (:import (org.apache.mxnet Shape NDArrayFuncReturn NDArray) (scala.collection Map Set) @@ -183,3 +184,10 @@ (deftest test-validate (is (nil? (util/validate! string? "foo" "Not a string!"))) (is (thrown-with-msg? Exception #"Not a string!" (util/validate! ::x 1 "Not a string!")))) + +(deftest test-approx= + (let [data1 [1 1 1 1] + data2 [1 1 1 1 9 9 9 9] + data3 [1 1 1 2]] + (is (not (test-util/approx= 1e-9 data1 data2))) + (is (test-util/approx= 2 data1 data3)))) \ No newline at end of file diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 46b21a90d4c6..d6d619f30cab 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2475,7 +2475,7 @@ def moveaxis(tensor, source, destination): # pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name -def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): +def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t): """Returns evenly spaced values within a given interval. Values are generated within the half-open interval [`start`, `stop`). In other @@ -2519,7 +2519,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): if ctx is None: ctx = current_context() return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, - dtype=dtype, ctx=str(ctx)) + infer_range=infer_range, dtype=dtype, ctx=str(ctx)) # pylint: enable= no-member, protected-access, too-many-arguments diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 5f6cbd6b6e14..da5533f36668 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -2886,7 +2886,7 @@ def full(shape, val, dtype=None, **kwargs): return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs) # pylint: disable=redefined-outer-name -def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None): +def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None): """Returns evenly spaced values within a given interval. Parameters @@ -2911,7 +2911,7 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None): if dtype is None: dtype = _numpy.float32 return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, - name=name, dtype=dtype) + infer_range=infer_range, name=name, dtype=dtype) def histogram(a, bins=10, range=None, **kwargs): """Compute the histogram of the input data. diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 548c30b73a14..8b5e1e010954 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -407,11 +407,10 @@ object NDArray extends NDArrayBase { * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. * @return NDArray of evenly spaced values in the specified range. */ - def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, - repeat: Int = 1, ctx: Context = Context.defaultCtx, - dType: DType = Base.MX_REAL_TYPE): NDArray = { - val params = Map("start" -> start, "step" -> step, - "repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString()) + def arange(start: Float, stop: Option[Float], step: Float, + repeat: Int, ctx: Context, dType: DType): NDArray = { + val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, + "infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString()) val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 194d3681523f..e3e1a320358e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -955,9 +955,28 @@ object Symbol extends SymbolBase { * @return Symbol The created Symbol. */ def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, - repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = { - val params = Map("start" -> start, "step" -> step, - "repeat" -> repeat, "dtype" -> dType.toString()) + repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = { + arange(start, stop, step, repeat, infer_range = false, name, dType) + } + + /** + * Returns evenly spaced values within a given interval. + * stop value can be infered from the output shape, + * which must be known from the rest of the net. + * @param start Start of interval. The default start value is 0. + * @param stop End of interval. + * @param step Spacing between values. The default step size is 1. + * @param repeat Number of times to repeat each element. The default repeat count is 1. + * @param infer_range Infer the stop value from output shape + * @param ctx Device context. Default context is the current default context. + * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. + * @return NDArray of evenly spaced values in the specified range. + */ + def arange(start: Float, stop: Option[Float], step: Float, + repeat: Int, infer_range: Boolean, name: String, + dType: DType): Symbol = { + val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, + "infer_range" -> infer_range, "dtype" -> dType.toString()) val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams) } diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 4af3a40f42ab..304911a02a78 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -123,6 +123,7 @@ struct RangeParam : public dmlc::Parameter { dmlc::optional stop; double step; int repeat; + bool infer_range; std::string ctx; int dtype; DMLC_DECLARE_PARAMETER(RangeParam) { @@ -140,6 +141,10 @@ struct RangeParam : public dmlc::Parameter { .set_default(1) .describe("The repeating time of all elements." " E.g repeat=3, the element a will be repeated three times --> a, a, a."); + DMLC_DECLARE_FIELD(infer_range) + .set_default(false) + .describe("Whether to infer the stop position from the start, step, repeat, and output tensor" + "size."); DMLC_DECLARE_FIELD(ctx) .set_default("") .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." @@ -176,7 +181,7 @@ struct InitOpWithScalarParam : dmlc::Parameter { inline void RangeParamParser(nnvm::NodeAttrs* attrs) { RangeParam param; param.Init(attrs->dict); - if (!static_cast(param.stop)) { + if (!static_cast(param.infer_range) && !static_cast(param.stop)) { param.stop = param.start; param.start = 0; } @@ -471,6 +476,9 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs, << "Range does not support step=0, received " << param.step; CHECK(param.repeat > 0) << "Range only supports repeat > 0, received " << param.repeat; + if (param.infer_range && !param.stop.has_value()) { + return false; + } if (param.step > 0) { CHECK(param.start < param.stop.value()) << "Invalid range (start, stop, step) = " diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index fc6b81454229..fd60611add8c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3646,10 +3646,18 @@ def test_arange(): nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype) assert_almost_equal(np_out, nd_out.asnumpy()) + def test_arange_inferstop(): + s = mx.sym.arange(start=0, stop=None, infer_range=True) + s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5])) + exe = s.bind(ctx=mx.cpu(), args={}) + exe.forward() + assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4])) + test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32) test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32) test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16) test_arange() + test_arange_inferstop() @with_seed()