Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Allow stop of arange to be inferred from dims. (#12064)
Browse files Browse the repository at this point in the history
* Allow stop of arange to be inferred from dims.

Enabled via a flag.

* modify NDArray/Symbol to add infer_range param

* Add test for arange-with-inference.

* Add a comment to readme about JDK 8.

* Fix approx=.

Include a test of this fix as well.
  • Loading branch information
Taliesin Beynon authored and nswamy committed Aug 24, 2018
1 parent 5189495 commit 7bfe427
Show file tree
Hide file tree
Showing 12 changed files with 85 additions and 18 deletions.
4 changes: 3 additions & 1 deletion contrib/clojure-package/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ The jars from maven with the needed MXNet native binaries in it. On startup, the

### Build from MXNET Source

Checkout the latest sha from the main package
First, ensure you have JDK 8 on your system. Later versions may produce cryptic build errors mentioning `scala.reflect.internal.MissingRequirementError`.

Checkout the latest SHA from the main package:

`git clone --recursive https://github.com/apache/incubator-mxnet.git ~/mxnet`
`cd ~/mxnet`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
(NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
([start stop]
(arange start stop {})))

(defn slice
"Return a sliced NDArray that shares memory with current one."
([ndarray i]
Expand Down
12 changes: 11 additions & 1 deletion contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,20 @@
([start stop {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
(Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype))
(Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype))
([start stop]
(arange start stop {})))

(defn arange-with-inference
"Behaves like arange operator, but infers the stop value from the output shape,
which must be known from the rest of the net."
([start {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
(Symbol/arange (float start) ($/option nil) step repeat true nil dtype))
([start]
(arange-with-inference start {})))

;;; manually defined because of a conflicting arity of 2 with the auto-gen
(defn min
([sym-name kwargs-map symbol-list kwargs-map-1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,17 @@
(is (= 0 (count (executor/grad-arrays exec))))
(is (approx= 1e-4 result (-> (executor/outputs exec) (first))))))

(deftest test-arange-with-inference
(let [arange (sym/arange-with-inference 0)
data (sym/variable "data")
added (sym/+ arange data)
result (range 0 4)
data-tmp (ndarray/zeros [4])
exec (sym/bind added (context/default-context) {"data" data-tmp})]
(executor/forward exec)
(is (= 0 (count (executor/grad-arrays exec))))
(is (approx= 1e-4 result (-> (executor/outputs exec) (first))))))

(deftest test-scalar-pow
(let [data (sym/variable "data")
shape-vec [1 1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
(if (and (number? x) (number? y))
(let [diff (Math/abs (- x y))]
(< diff tolerance))
(reduce (fn [x y] (and x y))
(map #(approx= tolerance %1 %2) x y))))
(and
(= (count x) (count y))
(reduce (fn [x y] (and x y))
(map #(approx= tolerance %1 %2) x y)))))

Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.symbol :as sym]
[org.apache.clojure-mxnet.test-util :as test-util]
[clojure.spec.alpha :as s])
(:import (org.apache.mxnet Shape NDArrayFuncReturn NDArray)
(scala.collection Map Set)
Expand Down Expand Up @@ -183,3 +184,10 @@
(deftest test-validate
(is (nil? (util/validate! string? "foo" "Not a string!")))
(is (thrown-with-msg? Exception #"Not a string!" (util/validate! ::x 1 "Not a string!"))))

(deftest test-approx=
(let [data1 [1 1 1 1]
data2 [1 1 1 1 9 9 9 9]
data3 [1 1 1 2]]
(is (not (test-util/approx= 1e-9 data1 data2)))
(is (test-util/approx= 2 data1 data3))))
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2475,7 +2475,7 @@ def moveaxis(tensor, source, destination):


# pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name
def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t):
"""Returns evenly spaced values within a given interval.
Values are generated within the half-open interval [`start`, `stop`). In other
Expand Down Expand Up @@ -2519,7 +2519,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
if ctx is None:
ctx = current_context()
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
dtype=dtype, ctx=str(ctx))
infer_range=infer_range, dtype=dtype, ctx=str(ctx))
# pylint: enable= no-member, protected-access, too-many-arguments


Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2886,7 +2886,7 @@ def full(shape, val, dtype=None, **kwargs):
return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs)

# pylint: disable=redefined-outer-name
def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):
"""Returns evenly spaced values within a given interval.
Parameters
Expand All @@ -2911,7 +2911,7 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
if dtype is None:
dtype = _numpy.float32
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
name=name, dtype=dtype)
infer_range=infer_range, name=name, dtype=dtype)

def histogram(a, bins=10, range=None, **kwargs):
"""Compute the histogram of the input data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,10 @@ object NDArray extends NDArrayBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, ctx: Context = Context.defaultCtx,
dType: DType = Base.MX_REAL_TYPE): NDArray = {
val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString())
def arange(start: Float, stop: Option[Float], step: Float,
repeat: Int, ctx: Context, dType: DType): NDArray = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
}
Expand Down
25 changes: 22 additions & 3 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -955,9 +955,28 @@ object Symbol extends SymbolBase {
* @return Symbol The created Symbol.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "dtype" -> dType.toString())
repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
arange(start, stop, step, repeat, infer_range = false, name, dType)
}

/**
* Returns evenly spaced values within a given interval.
* stop value can be infered from the output shape,
* which must be known from the rest of the net.
* @param start Start of interval. The default start value is 0.
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param infer_range Infer the stop value from output shape
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Option[Float], step: Float,
repeat: Int, infer_range: Boolean, name: String,
dType: DType): Symbol = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> infer_range, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
}
Expand Down
10 changes: 9 additions & 1 deletion src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
dmlc::optional<double> stop;
double step;
int repeat;
bool infer_range;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(RangeParam) {
Expand All @@ -140,6 +141,10 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
.set_default(1)
.describe("The repeating time of all elements."
" E.g repeat=3, the element a will be repeated three times --> a, a, a.");
DMLC_DECLARE_FIELD(infer_range)
.set_default(false)
.describe("Whether to infer the stop position from the start, step, repeat, and output tensor"
"size.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
Expand Down Expand Up @@ -176,7 +181,7 @@ struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
RangeParam param;
param.Init(attrs->dict);
if (!static_cast<bool>(param.stop)) {
if (!static_cast<bool>(param.infer_range) && !static_cast<bool>(param.stop)) {
param.stop = param.start;
param.start = 0;
}
Expand Down Expand Up @@ -471,6 +476,9 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
<< "Range does not support step=0, received " << param.step;
CHECK(param.repeat > 0)
<< "Range only supports repeat > 0, received " << param.repeat;
if (param.infer_range && !param.stop.has_value()) {
return false;
}
if (param.step > 0) {
CHECK(param.start < param.stop.value())
<< "Invalid range (start, stop, step) = "
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3646,10 +3646,18 @@ def test_arange():
nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype)
assert_almost_equal(np_out, nd_out.asnumpy())

def test_arange_inferstop():
s = mx.sym.arange(start=0, stop=None, infer_range=True)
s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5]))
exe = s.bind(ctx=mx.cpu(), args={})
exe.forward()
assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4]))

test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32)
test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32)
test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16)
test_arange()
test_arange_inferstop()


@with_seed()
Expand Down

0 comments on commit 7bfe427

Please sign in to comment.