Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Allow stop of arange to be inferred from dims. #12064

Merged
merged 5 commits into from
Aug 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion contrib/clojure-package/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ The jars from maven with the needed MXNet native binaries in it. On startup, the

### Build from MXNET Source

Checkout the latest sha from the main package
First, ensure you have JDK 8 on your system. Later versions may produce cryptic build errors mentioning `scala.reflect.internal.MissingRequirementError`.

Checkout the latest SHA from the main package:

`git clone --recursive https://github.com/apache/incubator-mxnet.git ~/mxnet`
`cd ~/mxnet`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
(NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
([start stop]
(arange start stop {})))

(defn slice
"Return a sliced NDArray that shares memory with current one."
([ndarray i]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,20 @@
([start stop {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
(Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype))
(Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype))
([start stop]
(arange start stop {})))

(defn arange-with-inference
"Behaves like arange operator, but infers the stop value from the output shape,
which must be known from the rest of the net."
([start {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
(Symbol/arange (float start) ($/option nil) step repeat true nil dtype))
([start]
(arange-with-inference start {})))

;;; manually defined because of a conflicting arity of 2 with the auto-gen
(defn min
([sym-name kwargs-map symbol-list kwargs-map-1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,17 @@
(is (= 0 (count (executor/grad-arrays exec))))
(is (approx= 1e-4 result (-> (executor/outputs exec) (first))))))

(deftest test-arange-with-inference
(let [arange (sym/arange-with-inference 0)
data (sym/variable "data")
added (sym/+ arange data)
result (range 0 4)
data-tmp (ndarray/zeros [4])
exec (sym/bind added (context/default-context) {"data" data-tmp})]
(executor/forward exec)
(is (= 0 (count (executor/grad-arrays exec))))
(is (approx= 1e-4 result (-> (executor/outputs exec) (first))))))

(deftest test-scalar-pow
(let [data (sym/variable "data")
shape-vec [1 1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
(if (and (number? x) (number? y))
(let [diff (Math/abs (- x y))]
(< diff tolerance))
(reduce (fn [x y] (and x y))
(map #(approx= tolerance %1 %2) x y))))
(and
(= (count x) (count y))
(reduce (fn [x y] (and x y))
(map #(approx= tolerance %1 %2) x y)))))

Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.symbol :as sym]
[org.apache.clojure-mxnet.test-util :as test-util]
[clojure.spec.alpha :as s])
(:import (org.apache.mxnet Shape NDArrayFuncReturn NDArray)
(scala.collection Map Set)
Expand Down Expand Up @@ -183,3 +184,10 @@
(deftest test-validate
(is (nil? (util/validate! string? "foo" "Not a string!")))
(is (thrown-with-msg? Exception #"Not a string!" (util/validate! ::x 1 "Not a string!"))))

(deftest test-approx=
(let [data1 [1 1 1 1]
data2 [1 1 1 1 9 9 9 9]
data3 [1 1 1 2]]
(is (not (test-util/approx= 1e-9 data1 data2)))
(is (test-util/approx= 2 data1 data3))))
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2475,7 +2475,7 @@ def moveaxis(tensor, source, destination):


# pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name
def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t):
"""Returns evenly spaced values within a given interval.

Values are generated within the half-open interval [`start`, `stop`). In other
Expand Down Expand Up @@ -2519,7 +2519,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
if ctx is None:
ctx = current_context()
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
dtype=dtype, ctx=str(ctx))
infer_range=infer_range, dtype=dtype, ctx=str(ctx))
# pylint: enable= no-member, protected-access, too-many-arguments


Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2886,7 +2886,7 @@ def full(shape, val, dtype=None, **kwargs):
return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs)

# pylint: disable=redefined-outer-name
def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):
"""Returns evenly spaced values within a given interval.

Parameters
Expand All @@ -2911,7 +2911,7 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
if dtype is None:
dtype = _numpy.float32
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
name=name, dtype=dtype)
infer_range=infer_range, name=name, dtype=dtype)

def histogram(a, bins=10, range=None, **kwargs):
"""Compute the histogram of the input data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,10 @@ object NDArray extends NDArrayBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, ctx: Context = Context.defaultCtx,
dType: DType = Base.MX_REAL_TYPE): NDArray = {
val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString())
def arange(start: Float, stop: Option[Float], step: Float,
repeat: Int, ctx: Context, dType: DType): NDArray = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
}
Expand Down
25 changes: 22 additions & 3 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -955,9 +955,28 @@ object Symbol extends SymbolBase {
* @return Symbol The created Symbol.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
val params = Map("start" -> start, "step" -> step,
"repeat" -> repeat, "dtype" -> dType.toString())
repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
arange(start, stop, step, repeat, infer_range = false, name, dType)
}

/**
* Returns evenly spaced values within a given interval.
* stop value can be infered from the output shape,
* which must be known from the rest of the net.
* @param start Start of interval. The default start value is 0.
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param infer_range Infer the stop value from output shape
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Option[Float], step: Float,
repeat: Int, infer_range: Boolean, name: String,
dType: DType): Symbol = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> infer_range, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
}
Expand Down
10 changes: 9 additions & 1 deletion src/operator/tensor/init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
dmlc::optional<double> stop;
double step;
int repeat;
bool infer_range;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(RangeParam) {
Expand All @@ -140,6 +141,10 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
.set_default(1)
.describe("The repeating time of all elements."
" E.g repeat=3, the element a will be repeated three times --> a, a, a.");
DMLC_DECLARE_FIELD(infer_range)
.set_default(false)
.describe("Whether to infer the stop position from the start, step, repeat, and output tensor"
"size.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
Expand Down Expand Up @@ -176,7 +181,7 @@ struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
RangeParam param;
param.Init(attrs->dict);
if (!static_cast<bool>(param.stop)) {
if (!static_cast<bool>(param.infer_range) && !static_cast<bool>(param.stop)) {
param.stop = param.start;
param.start = 0;
}
Expand Down Expand Up @@ -471,6 +476,9 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
<< "Range does not support step=0, received " << param.step;
CHECK(param.repeat > 0)
<< "Range only supports repeat > 0, received " << param.repeat;
if (param.infer_range && !param.stop.has_value()) {
return false;
}
if (param.step > 0) {
CHECK(param.start < param.stop.value())
<< "Invalid range (start, stop, step) = "
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3555,10 +3555,18 @@ def test_arange():
nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype)
assert_almost_equal(np_out, nd_out.asnumpy())

def test_arange_inferstop():
s = mx.sym.arange(start=0, stop=None, infer_range=True)
s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5]))
exe = s.bind(ctx=mx.cpu(), args={})
exe.forward()
assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4]))

test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32)
test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32)
test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16)
test_arange()
test_arange_inferstop()


@with_seed()
Expand Down