-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Allow stop of arange to be inferred from dims. #12064
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,7 +135,17 @@ | |
([start stop {:keys [step repeat dtype] | ||
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} | ||
:as opts}] | ||
(Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype)) | ||
(Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype)) | ||
([start stop] | ||
(arange start stop {}))) | ||
|
||
(defn arange-with-inference | ||
"Behaves like arange operator, but infers the stop value from the output shape, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here - Great job adding the clojure functions. If you want to add the corresponding test that would be awesome too https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj#L214. It can also be done in a follow up PR if that works better 😸 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gigasquid happy to try to do that! i haven't ever used Clojure or Scala before. But I've run into a problem even getting the first step,
This seems to be related to https://issues.scala-lang.org/browse/SI-9103, but that issue is still open. I have no idea whats going on or how to make progress. I reran the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @taliesinb I'm impressed that you jumped in there on the Clojure and Scala 💯 - From the issue it seems like it is the JDK you are using. Using JDK 8 should solve the problems. If you have multiple versions of the JDK installed, you should just be able to switch by using an export of the right JAVA_HOME see here. Give it a try and see how it goes. If you don't want to hold up this PR, I'd be happy to assist on a follow up PR if you'd like 😸 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gigasquid great that worked! thanks for the help! I'll keep you posted as I (hopefully) make progress with this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @taliesinb are you making a new PR for Clojure or do you want to make changes to this one ? This is good for Scala APIs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm @gigasquid I'm finding the instructions in the
I would find this easier to understand if it was very explicit, such as "replace X with Y in section Z". Here is what my project.clj contained out of the box:
At this point, not knowing what to replace with what, I read the section "Cloning the repo and running from source", which mentions uncommenting rather than replacing. That section is also a bit confusing:
Which line? What is a "native version of the line"? Perhaps it could say "you will need to find and uncomment the appropriate line in the dependencies section of the project.clj file, and comment the rest". We could also make the project.clj section clearer so its more obvious what to do:
Now, while the instructions could be a bit clearer, I did figure out the point eventually, and so I tried adding this line and commenting the rest:
After running
That WARN makes it sound like it's not using the library I built earlier using How should I fix this? Also, an ergonomics question: the test suite takes a while to run. With Python, it was very easy to just run the new tests I added using e.g. Thanks in advance for your help! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @nswamy I'm now skeptical of my Scala changes. For example, I'm not sure that the additions to The new functionality should work in the symbolic context, however. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @taliesinb Thanks for the feedback on the wording. I'll update it to be more clear. It seems like you are doing everything exactly right. Once you do a scalainstall it will install locally in your maven a new Thanks again for the feedback and let me know if you have any other issues There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gigasquid thanks for the info. I'll get back to this tomorrow hopefully. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gigasquid ok we're good to go. I removed the pointless imperative function version of However, in doing this, I think I've picked up a problem with For example, try change the test starting on line 200 to the following:
(I've introduced the 9 9 9 9 9 9 here). This test still passes. I've reported the issue here: #12320, and fixed it in this PR. It doesn't produce any regressions, luckily! If my new test looks good to you, we should be ready to merge! |
||
which must be known from the rest of the net." | ||
([start stop {:keys [step repeat dtype] | ||
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} | ||
:as opts}] | ||
(Symbol/arange (float start) ($/option (float stop)) step repeat true nil dtype)) | ||
([start stop] | ||
(arange start stop {}))) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -407,11 +407,30 @@ object NDArray extends NDArrayBase { | |
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. | ||
* @return NDArray of evenly spaced values in the specified range. | ||
*/ | ||
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, | ||
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @taliesinb Thanks for making this change. There is a compile error and CI is breaking. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @lanking520 FYI There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @nswamy oh thanks! my bad. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @nswamy Looks good! |
||
ctx: Context = Context.defaultCtx, dType: DType = Base.MX_REAL_TYPE): NDArray = { | ||
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, | ||
"infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString()) | ||
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) | ||
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0) | ||
} | ||
|
||
/** | ||
* Behaves like arange operator, but infers the stop value from the output shape, | ||
* which must be known from the rest of the net. | ||
* @param start Start of interval. The default start value is 0. | ||
* @param stop End of interval. | ||
* @param step Spacing between values. The default step size is 1. | ||
* @param repeat Number of times to repeat each element. The default repeat count is 1. | ||
* @param ctx Device context. Default context is the current default context. | ||
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. | ||
* @return NDArray of evenly spaced values in the specified range. | ||
*/ | ||
def arangeWithInference(start: Float, stop: Option[Float] = None, step: Float = 1.0f, | ||
repeat: Int = 1, ctx: Context = Context.defaultCtx, | ||
dType: DType = Base.MX_REAL_TYPE): NDArray = { | ||
val params = Map("start" -> start, "step" -> step, | ||
"repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString()) | ||
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, | ||
"infer_range" -> true, "ctx" -> ctx.toString, "dtype" -> dType.toString()) | ||
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) | ||
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice job adding the clojure function with nice documentation. 👍 If you are feeling up to it you could also add the corresponding test for it here https://github.com/apache/incubator-mxnet/blob/master/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj#L141