Skip to content

Commit

Permalink
Add default parameters for Scala NDArray.arange (apache#13816)
Browse files Browse the repository at this point in the history
* Add default arguments for arange

* Remove redundant tag

* Update test

* Remove infer_range for python ndarray.arange

* Update CONTRIBUTORS.md

* Deprecate infer_range argument in ndarray.arange
  • Loading branch information
kiendang authored and haohuw committed Jun 23, 2019
1 parent 4545b1f commit e8abcfe
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ List of Contributors
* [Harsh Patel](https://github.com/harshp8l)
* [Xiao Wang](https://github.com/BeyonderXX)
* [Piyush Ghai](https://github.com/piyushghai)
* [Dang Trung Kien](https://github.com/kiendang)
* [Zach Boldyga](https://github.com/zboldyga)
* [Gordon Reid](https://github.com/gordon1992)

Expand Down
7 changes: 5 additions & 2 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2544,7 +2544,7 @@ def moveaxis(tensor, source, destination):


# pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name
def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t):
def arange(start, stop=None, step=1.0, repeat=1, infer_range=None, ctx=None, dtype=mx_real_t):
"""Returns evenly spaced values within a given interval.
Values are generated within the half-open interval [`start`, `stop`). In other
Expand Down Expand Up @@ -2588,10 +2588,13 @@ def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dt
>>> mx.nd.arange(2, 6, step=2, repeat=3, dtype='int32').asnumpy()
array([2, 2, 2, 4, 4, 4], dtype=int32)
"""
if infer_range is not None:
warnings.warn('`infer_range` argument has been deprecated',
DeprecationWarning)
if ctx is None:
ctx = current_context()
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
infer_range=infer_range, dtype=dtype, ctx=str(ctx))
infer_range=False, dtype=dtype, ctx=str(ctx))
# pylint: enable= no-member, protected-access, too-many-arguments


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,15 +575,13 @@ object NDArray extends NDArrayBase {
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param infer_range
* When set to True, infer the stop position from the start, step,
* repeat, and output tensor size.
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Option[Float], step: Float,
repeat: Int, ctx: Context, dType: DType): NDArray = {
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, ctx: Context = Context.defaultCtx,
dType: DType = Base.MX_REAL_TYPE): NDArray = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val stop = start + scala.util.Random.nextFloat() * 100
val step = scala.util.Random.nextFloat() * 4
val repeat = 1
val result = (start.toDouble until stop.toDouble by step.toDouble)
.flatMap(x => Array.fill[Float](repeat)(x.toFloat))
val range = NDArray.arange(start = start, stop = Some(stop), step = step,
repeat = repeat, ctx = Context.cpu(), dType = DType.Float32)
assert(CheckUtils.reldiff(result.toArray, range.toArray) <= 1e-4f)

val result1 = (start.toDouble until stop.toDouble by step.toDouble)
.flatMap(x => Array.fill[Float](repeat)(x.toFloat))
val range1 = NDArray.arange(start = start, stop = Some(stop), step = step,
repeat = repeat)
assert(CheckUtils.reldiff(result1.toArray, range1.toArray) <= 1e-4f)

val result2 = (0.0 until stop.toDouble by step.toDouble)
.flatMap(x => Array.fill[Float](repeat)(x.toFloat))
val range2 = NDArray.arange(stop, step = step, repeat = repeat)
assert(CheckUtils.reldiff(result2.toArray, range2.toArray) <= 1e-4f)

val result3 = 0f to stop by 1f
val range3 = NDArray.arange(stop)
assert(CheckUtils.reldiff(result3.toArray, range3.toArray) <= 1e-4f)
}
}

Expand Down

0 comments on commit e8abcfe

Please sign in to comment.