Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add default parameters for Scala NDArray.arange #13816

Merged
merged 8 commits into from
Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -405,15 +405,13 @@ object NDArray extends NDArrayBase {
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param infer_range
* When set to True, infer the stop position from the start, step,
* repeat, and output tensor size.
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Option[Float], step: Float,
repeat: Int, ctx: Context, dType: DType): NDArray = {
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless you are certain that this param is not going to be used, please place it at the end of the param list and rename it as inferRange follows Java camel case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not being used now. Do you mean we still include it now in case it'll be used later?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to #12064 (comment) infer_range is only applicable to symbolic API and was meant to be removed from ndarray API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add @taliesinb here to bring more context since I am not very familiar with it

repeat: Int = 1, ctx: Context = Context.defaultCtx,
dType: DType = Base.MX_REAL_TYPE): NDArray = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val stop = start + scala.util.Random.nextFloat() * 100
val step = scala.util.Random.nextFloat() * 4
val repeat = 1
val result = (start.toDouble until stop.toDouble by step.toDouble)
.flatMap(x => Array.fill[Float](repeat)(x.toFloat))
val range = NDArray.arange(start = start, stop = Some(stop), step = step,
repeat = repeat, ctx = Context.cpu(), dType = DType.Float32)
assert(CheckUtils.reldiff(result.toArray, range.toArray) <= 1e-4f)

val result1 = (start.toDouble until stop.toDouble by step.toDouble)
.flatMap(x => Array.fill[Float](repeat)(x.toFloat))
val range1 = NDArray.arange(start = start, stop = Some(stop), step = step,
repeat = repeat)
assert(CheckUtils.reldiff(result1.toArray, range1.toArray) <= 1e-4f)

val result2 = (0.0 until stop.toDouble by step.toDouble)
.flatMap(x => Array.fill[Float](repeat)(x.toFloat))
val range2 = NDArray.arange(stop, step = step, repeat = repeat)
assert(CheckUtils.reldiff(result2.toArray, range2.toArray) <= 1e-4f)

val result3 = 0f to stop by 1f
val range3 = NDArray.arange(stop)
assert(CheckUtils.reldiff(result3.toArray, range3.toArray) <= 1e-4f)
}
}

Expand Down