This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Allow stop of arange to be inferred from dims. #12064

Merged
merged 5 commits into from
Aug 24, 2018
@@ -86,7 +86,17 @@
([start stop {:keys [step repeat ctx dtype]
:or {step (float 1) repeat (int 1) ctx (mx-context/default-context) dtype base/MX_REAL_TYPE}
:as opts}]
- (NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
+ (NDArray/arange (float start) ($/option (float stop)) step repeat false ctx dtype))
([start stop]
(arange start stop {})))

(defn arange-with-inference
"Behaves like arange operator, but infers the stop value from the output shape,
Member

Nice job adding the Clojure function with good documentation. 👍 If you are feeling up to it, you could also add the corresponding test here: https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj#L141

which must be known from the rest of the net."
([start stop {:keys [step repeat ctx dtype]
:or {step (float 1) repeat (int 1) ctx (mx-context/default-context) dtype base/MX_REAL_TYPE}
:as opts}]
(NDArray/arange (float start) ($/option (float stop)) step repeat true ctx dtype))
([start stop]
(arange start stop {})))

@@ -135,7 +135,17 @@
([start stop {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
- (Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype))
+ (Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype))
([start stop]
(arange start stop {})))

(defn arange-with-inference
"Behaves like arange operator, but infers the stop value from the output shape,
Member

Same here - great job adding the Clojure functions. If you want to add the corresponding test, that would be awesome too: https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj#L214. It can also be done in a follow-up PR if that works better 😸

Contributor Author

@gigasquid happy to try to do that! I haven't ever used Clojure or Scala before, and I've run into a problem even getting the first step, make scalapkg, to work on macOS. The make initially failed, complaining that it couldn't find the mvn executable. I assumed that was Maven; brew install maven had me first brew install java. Then make scalapkg seemed to be happy, and downloaded a bunch of stuff (including Scala). But it failed with this:

[INFO] /Users/taliesinb/git/MXNet/scala-package/init/src/main/scala:-1: info: compiling
[INFO] Compiling 2 source files to /Users/taliesinb/git/MXNet/scala-package/init/target/classes at 1534710433224
Downloading from central: https://repo.maven.apache.org/maven2/org/scala-lang/modules/scala-xml_2.11/1.0.4/scala-xml_2.11-1.0.4.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scala-lang/modules/scala-xml_2.11/1.0.4/scala-xml_2.11-1.0.4.jar (648 kB at 755 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.4/scala-library-2.11.4.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.4/scala-library-2.11.4.jar (5.5 MB at 1.3 MB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.6/scala-library-2.11.6.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scala-lang/scala-library/2.11.6/scala-library-2.11.6.jar (5.6 MB at 1.2 MB/s)
[INFO] compiler plugin: BasicArtifact(org.scalamacros,paradise_2.11.8,2.1.0)
Downloading from central: https://repo.maven.apache.org/maven2/org/scalamacros/paradise_2.11.8/2.1.0/paradise_2.11.8-2.1.0.jar
Downloaded from central: https://repo.maven.apache.org/maven2/org/scalamacros/paradise_2.11.8/2.1.0/paradise_2.11.8-2.1.0.jar (271 kB at 397 kB/s)
[ERROR] error: scala.reflect.internal.MissingRequirementError: object java.lang.Object in compiler mirror not found.
[ERROR] 	at scala.reflect.internal.MissingRequirementError$.signal(MissingRequirementError.scala:17)
[ERROR] 	at scala.reflect.internal.MissingRequirementError$.notFound(MissingRequirementError.scala:18)
[INFO] 	at scala.reflect.internal.Mirrors$RootsBase.getModuleOrClass(Mirrors.scala:53)

This seems to be related to https://issues.scala-lang.org/browse/SI-9103, but that issue is still open. I have no idea what's going on or how to make progress. I reran make scalapkg to no effect; here's a gist with the full output: https://gist.github.com/taliesinb/d0f09e9f0202c3983298511383542f59. Do you have any suggestions?

Member

@taliesinb I'm impressed that you jumped in there on the Clojure and Scala 💯 - From the issue, it seems the problem is the JDK you are using. Using JDK 8 should solve it. If you have multiple versions of the JDK installed, you should be able to switch by exporting the right JAVA_HOME; see here. Give it a try and see how it goes. If you don't want to hold up this PR, I'd be happy to assist on a follow-up PR if you'd like 😸

Contributor Author

@gigasquid great, that worked! Thanks for the help! I'll keep you posted as I (hopefully) make progress with this.

Member

@taliesinb are you making a new PR for Clojure, or do you want to make changes to this one? This is good for the Scala APIs.
Thanks for the great work 👍

Contributor Author

Hmm @gigasquid I'm finding the instructions in the README.md file a little unclear in places. For example, under "Build from MXNET Source", I find this instruction a bit cryptic:

then replace the correct jar for your architecture in the project.clj, example [org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]

I would find this easier to understand if it was very explicit, such as "replace X with Y in section Z".

Here is what my project.clj contained out of the box:

                 ;; Jars from Nexus
                 ;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.2.1"]
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.2.1"]
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.2.1"]

                 ;;; CI
                 [org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"]

At this point, not knowing what to replace with what, I read the section "Cloning the repo and running from source", which mentions uncommenting rather than replacing. That section is also a bit confusing:

you will need to replace the native version of the line in the project dependencies with your configuration.

Which line? What is a "native version of the line"? Perhaps it could say "you will need to find and uncomment the appropriate line in the dependencies section of the project.clj file, and comment out the rest". We could also make the project.clj section clearer so it's more obvious what to do:

                 ;; default behavior, to be used by the CI bot on github; comment this line and
                 ;; uncomment the appropriate line in one of the other sections
                 [org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"]

                 ;; use a prebuilt JAR from Nexus
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.2.1"]
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.2.1"]
                 ;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.2.1"]

                 ;;; build a local JAR from source
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-cpu "1.3.0-SNAPSHOT"]
                 ;[org.apache.mxnet/mxnet-full_2.11-linux-x86_64-gpu "1.3.0-SNAPSHOT"]
                 ;[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]

Now, while the instructions could be a bit clearer, I did figure out the point eventually, and so I tried adding this line and commenting the rest:

[org.apache.mxnet/mxnet-full_2.11-osx-x86_64-cpu "1.3.0-SNAPSHOT"]

After running lein clean and lein test I get this:

Generating symbol file
INFO  MXNetJVM: Try loading mxnet-scala from native path.
INFO  MXNetJVM: Try loading mxnet-scala-osx-x86_64-gpu from native path.
INFO  MXNetJVM: Try loading mxnet-scala-osx-x86_64-cpu from native path.
WARN  MXNetJVM: MXNet Scala native library not found in path. Copying native library from the archive. Consider installing the library somewhere in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH), or specifying by Java cmd option -Djava.library.path=[lib path].
INFO  org.apache.mxnet.util.NativeLibraryLoader: Replaced .dylib with .jnilib
INFO  org.apache.mxnet.util.NativeLibraryLoader: Loading libmxnet-scala.jnilib from /lib/native/ copying to mxnet-scala
[2

That WARN makes it sound like it's not using the library I built earlier using make scalainstall, which will mean I can't actually test my new functionality! Wasn't make scalainstall supposed to make the MXNet scala libraries available for everyone on my system?

How should I fix this?

Also, an ergonomics question: the test suite takes a while to run. With Python, it was very easy to just run the new tests I added using e.g. nosetests -v tests/python/unittest/test_operator.py. Is there a similar incantation for Clojure?

Thanks in advance for your help!

Contributor Author

@taliesinb taliesinb Aug 20, 2018

@nswamy I'm now skeptical of my Scala changes. For example, I'm not sure that the additions to arange in NDArray.scala are correct. My concern is that the only way that you can even use the new inference feature is via backward inference, because the inference is based on the output shape of the tensor produced by arange, which must be inferred from a different part of the graph. So unless I'm missing something, calling the imperative arange function with infer_range = true will always fail as it has to produce an NDArray immediately, but this is not possible because backward inference is only relevant for symbols.

The new functionality should work in the symbolic context, however.
EDIT: to answer your original question, I'd prefer to add a Scala test to verify this last claim! If I did the wrong thing before... I don't trust myself now. Plus, if you agree, I should delete the imperative version of the new arange functionality from all language APIs.
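
The point about backward inference can be illustrated with a small pure-Python sketch. This is not MXNet code: the names arange_from_size and arange_with_inference are hypothetical, and the fill rule (element i equals start + (i // repeat) * step) is an assumption chosen to be consistent with the test added in this PR.

```python
def arange_from_size(start, step, repeat, size):
    """Fill `size` elements; no stop value is needed (assumed fill rule)."""
    return [start + (i // repeat) * step for i in range(size)]


def arange_with_inference(start, step=1.0, repeat=1, output_size=None):
    """Hypothetical stand-in for arange with infer_range=True.

    `output_size` models a shape supplied by the rest of the graph.
    """
    if output_size is None:
        # An imperative call has no surrounding graph, so nothing can
        # supply the output shape and no array can be produced.
        raise ValueError("output shape is unknown; stop cannot be inferred")
    return arange_from_size(start, step, repeat, output_size)


# Symbolic-style use: another part of the graph fixes the size at 5.
print(arange_with_inference(0, output_size=5))
```

Calling arange_with_inference(0) without output_size raises, which mirrors the situation described above for the imperative API.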

Member

@taliesinb Thanks for the feedback on the wording. I'll update it to be more clear. It seems like you are doing everything exactly right. Once you do a scalainstall it will install locally in your maven a new 1.3.0-SNAPSHOT. Since you updated your project.clj to use this, it will load up the updated jar. The WARN is again misleading and can be improved, but it should be working :)
As far as running just one test, you can certainly do that with lein test :only org.apache.clojure-mxnet.ndarray-test and lein test :only org.apache.clojure-mxnet.operator-test.

Thanks again for the feedback and let me know if you have any other issues

Contributor Author

@gigasquid thanks for the info. I'll get back to this tomorrow hopefully.

Contributor Author

@taliesinb taliesinb Aug 23, 2018

@gigasquid ok we're good to go. I removed the pointless imperative function version of arange-with-inference, so I only had to add a test to operator_test.clj.

However, in doing this, I think I've picked up a problem with approx=, in which it incorrectly returns true if one of the comparisands (is that a word??) is shorter than the other, and differs in the remaining elements that the other does not have.

For example, try changing the test starting on line 200 to the following:

(deftest ones
  (let [ones (sym/ones [2 2])
        exec (sym/simple-bind ones (context/default-context))]
    (is (approx= 1e-4
                 [1 1 1 1 9 9 9 9 9 9]
                 (-> exec (executor/forward) (executor/outputs) (first))))))

(I've introduced the 9 9 9 9 9 9 here). This test still passes.

I've reported the issue here: #12320, and fixed it in this PR. It doesn't produce any regressions, luckily!

If my new test looks good to you, we should be ready to merge!
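
The failure mode described above is easy to reproduce in any language whose pairwise iteration stops at the shorter input. The sketch below is Python, not the actual approx= helper: it assumes approx= pairs elements the way a multi-collection map would, and both function names are hypothetical.

```python
def approx_equal_truncating(tolerance, xs, ys):
    """Compare element-wise; zip() silently stops at the shorter input."""
    return all(abs(a - b) <= tolerance for a, b in zip(xs, ys))


def approx_equal_strict(tolerance, xs, ys):
    """Same comparison, but inputs of different lengths are never equal."""
    if len(xs) != len(ys):
        return False
    return approx_equal_truncating(tolerance, xs, ys)


expected = [1, 1, 1, 1, 9, 9, 9, 9, 9, 9]
actual = [1.0, 1.0, 1.0, 1.0]  # stands in for the output of a 2x2 ones array

print(approx_equal_truncating(1e-4, expected, actual))  # extra 9s are ignored
print(approx_equal_strict(1e-4, expected, actual))      # length check catches it
```

Checking lengths before comparing elements is one straightforward way to close the gap; whether the fix in this PR takes exactly that form is not shown in this excerpt.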

which must be known from the rest of the net."
([start stop {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
(Symbol/arange (float start) ($/option (float stop)) step repeat true nil dtype))
([start stop]
(arange start stop {})))

4 changes: 2 additions & 2 deletions python/mxnet/ndarray/ndarray.py
@@ -2475,7 +2475,7 @@ def moveaxis(tensor, source, destination):


# pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name
- def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
+ def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t):
"""Returns evenly spaced values within a given interval.

Values are generated within the half-open interval [`start`, `stop`). In other
@@ -2519,7 +2519,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t):
if ctx is None:
ctx = current_context()
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
- dtype=dtype, ctx=str(ctx))
+ infer_range=infer_range, dtype=dtype, ctx=str(ctx))
# pylint: enable= no-member, protected-access, too-many-arguments


4 changes: 2 additions & 2 deletions python/mxnet/symbol/symbol.py
@@ -2886,7 +2886,7 @@ def full(shape, val, dtype=None, **kwargs):
return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs)

# pylint: disable=redefined-outer-name
- def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
+ def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None):
"""Returns evenly spaced values within a given interval.

Parameters
@@ -2911,7 +2911,7 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
if dtype is None:
dtype = _numpy.float32
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
- name=name, dtype=dtype)
+ infer_range=infer_range, name=name, dtype=dtype)

def histogram(a, bins=10, range=None, **kwargs):
"""Compute the histogram of the input data.
@@ -407,11 +407,30 @@ object NDArray extends NDArrayBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
- def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
+ def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1,
Member

@taliesinb Thanks for making this change. There is a compile error and CI is breaking.
I also wanted to change this slightly, so I fixed the compile error, modified NDArray/Symbol, and pushed a commit to your branch.

Contributor Author

@nswamy oh thanks! my bad.

Member

@nswamy Looks good!

ctx: Context = Context.defaultCtx, dType: DType = Base.MX_REAL_TYPE): NDArray = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
}

/**
* Behaves like arange operator, but infers the stop value from the output shape,
* which must be known from the rest of the net.
* @param start Start of interval. The default start value is 0.
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arangeWithInference(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, ctx: Context = Context.defaultCtx,
dType: DType = Base.MX_REAL_TYPE): NDArray = {
- val params = Map("start" -> start, "step" -> step,
- "repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString())
+ val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
+ "infer_range" -> true, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
}
27 changes: 23 additions & 4 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -954,10 +954,29 @@ object Symbol extends SymbolBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return Symbol The created Symbol.
*/
- def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
- repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
- val params = Map("start" -> start, "step" -> step,
- "repeat" -> repeat, "dtype" -> dType.toString())
+ def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1,
+ name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
+ val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
+ "infer_range" -> false, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
}

/**
* Behaves like arange operator, but infers the stop value from the output shape,
* which must be known from the rest of the net.
* @param start Start of interval. The default start value is 0.
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arangeWithInference(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, dType: DType = Base.MX_REAL_TYPE): Symbol = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> true, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
}
10 changes: 9 additions & 1 deletion src/operator/tensor/init_op.h
@@ -123,6 +123,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
dmlc::optional<double> stop;
double step;
int repeat;
bool infer_range;
std::string ctx;
int dtype;
DMLC_DECLARE_PARAMETER(RangeParam) {
@@ -140,6 +141,10 @@ struct RangeParam : public dmlc::Parameter<RangeParam> {
.set_default(1)
.describe("The repeating time of all elements."
" E.g repeat=3, the element a will be repeated three times --> a, a, a.");
DMLC_DECLARE_FIELD(infer_range)
.set_default(false)
.describe("Whether to infer the stop position from the start, step, repeat, and output tensor"
" size.");
DMLC_DECLARE_FIELD(ctx)
.set_default("")
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
@@ -176,7 +181,7 @@ struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> {
inline void RangeParamParser(nnvm::NodeAttrs* attrs) {
RangeParam param;
param.Init(attrs->dict);
- if (!static_cast<bool>(param.stop)) {
+ if (!static_cast<bool>(param.infer_range) && !static_cast<bool>(param.stop)) {
param.stop = param.start;
param.start = 0;
}
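
The effect of the parser change in this hunk can be restated as a short Python sketch. The function name normalize_range is hypothetical; the branch condition follows the modified if statement above.

```python
def normalize_range(start, stop=None, infer_range=False):
    """Mirror the parser's handling of a missing stop value."""
    if not infer_range and stop is None:
        # Classic single-argument form: arange(n) means the interval [0, n).
        return 0, start
    # Either stop was given, or it is left unset so that it can be
    # inferred later from the output shape.
    return start, stop


print(normalize_range(5))                    # single-argument form
print(normalize_range(0, infer_range=True))  # stop stays unset
```

Without the extra infer_range condition, the second call would be rewritten to the empty interval [0, 0), which is presumably why the check was added.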
@@ -471,6 +476,9 @@ inline bool RangeShape(
<< "Range does not support step=0, received " << param.step;
CHECK(param.repeat > 0)
<< "Range only supports repeat > 0, received " << param.repeat;
if (param.infer_range && !param.stop.has_value()) {
return false;
}
if (param.step > 0) {
CHECK(param.start < param.stop.value())
<< "Invalid range (start, stop, step) = "
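
A rough Python rendering of the shape function may help when reading this hunk. It is a sketch under assumptions, not the MXNet implementation: the name range_shape is hypothetical, returning None stands in for the C++ function returning false (shape not yet known), and both the size formula repeat * ceil((stop - start) / step) and the negative-step check are assumed, since the hunk is cut off before that code.

```python
import math


def range_shape(start, stop, step, repeat, infer_range=False):
    """Return the 1-D output shape, or None when it must come from elsewhere."""
    if step == 0:
        raise ValueError("Range does not support step=0")
    if repeat <= 0:
        raise ValueError("Range only supports repeat > 0")
    if infer_range and stop is None:
        # Defer: the shape has to be supplied by the rest of the graph.
        return None
    if stop is None:
        raise ValueError("stop is required unless infer_range is set")
    if step > 0 and not start < stop:
        raise ValueError("Invalid range (start, stop, step)")
    if step < 0 and not start > stop:  # assumed symmetric check
        raise ValueError("Invalid range (start, stop, step)")
    # Assumed size formula; the real computation is outside this hunk.
    return (repeat * math.ceil((stop - start) / step),)


print(range_shape(0, 5, 1, 1))
print(range_shape(0, None, 1, 1, infer_range=True))
```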
8 changes: 8 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -3555,10 +3555,18 @@ def test_arange():
nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype)
assert_almost_equal(np_out, nd_out.asnumpy())

def test_arange_inferstop():
s = mx.sym.arange(start=0, stop=None, infer_range=True)
s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5]))
exe = s.bind(ctx=mx.cpu(), args={})
exe.forward()
assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4]))

test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32)
test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32)
test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16)
test_arange()
test_arange_inferstop()


@with_seed()
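
The new test relies on elemwise_add forcing both inputs to share a shape, so the known shape of zeros(shape=[5]) flows back to the arange symbol. A toy Python model of that step is sketched below; infer_elemwise_shapes is a hypothetical name, and MXNet's real shape-inference pass is considerably more general.

```python
def infer_elemwise_shapes(lhs_shape, rhs_shape):
    """Return (lhs, rhs, out) shapes for an element-wise op; None means unknown."""
    if lhs_shape is not None and rhs_shape is not None and lhs_shape != rhs_shape:
        raise ValueError("element-wise inputs must have the same shape")
    known = lhs_shape if lhs_shape is not None else rhs_shape
    # If one side is known, the other side and the output take the same shape.
    return known, known, known


# arange has an unknown shape; zeros(shape=[5]) is known.
arange_shape, zeros_shape, out_shape = infer_elemwise_shapes(None, (5,))

# With the size known, start=0 and step=1 give the values the test expects.
values = [0 + i * 1 for i in range(arange_shape[0])]
print(arange_shape, values)
```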