Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
modify NDArray/Symbol to add infer_range param
Browse files Browse the repository at this point in the history
  • Loading branch information
nswamy committed Aug 19, 2018
1 parent 10851d2 commit bf27a07
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
22 changes: 11 additions & 11 deletions scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -407,30 +407,30 @@ object NDArray extends NDArrayBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1,
ctx: Context = Context.defaultCtx, dType: DType = Base.MX_REAL_TYPE): NDArray = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, ctx: Context = Context.defaultCtx,
dType: DType = Base.MX_REAL_TYPE): NDArray = {
arange(start, stop, step, repeat, infer_range = false, ctx, dType)
}

/**
* Behaves like arange operator, but infers the stop value from the output shape,
* Return evenly spaced values within a given interval,
* the stop value can be infered from the output shape,
* which must be known from the rest of the net.
* @param start Start of interval. The default start value is 0.
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param infer_range Infer the stop value from output shape
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arangeWithInference(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, ctx: Context = Context.defaultCtx,
dType: DType = Base.MX_REAL_TYPE): NDArray = {
def arange(start: Float, stop: Option[Float], step: Float,
repeat: Int, infer_range: Boolean, ctx: Context,
dType: DType): NDArray = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> true, "ctx" -> ctx.toString, "dtype" -> dType.toString())
"infer_range" -> infer_range, "ctx" -> ctx.toString, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0)
}
Expand Down
20 changes: 10 additions & 10 deletions scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -954,29 +954,29 @@ object Symbol extends SymbolBase {
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return Symbol The created Symbol.
*/
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, repeat: Int = 1,
name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> false, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = {
arange(start, stop, step, repeat, infer_range = false, name, dType)
}

/**
* Behaves like arange operator, but infers the stop value from the output shape,
* Returns evenly spaced values within a given interval.
* stop value can be infered from the output shape,
* which must be known from the rest of the net.
* @param start Start of interval. The default start value is 0.
* @param stop End of interval.
* @param step Spacing between values. The default step size is 1.
* @param repeat Number of times to repeat each element. The default repeat count is 1.
* @param infer_range Infer the stop value from output shape
* @param ctx Device context. Default context is the current default context.
* @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`.
* @return NDArray of evenly spaced values in the specified range.
*/
def arangeWithInference(start: Float, stop: Option[Float] = None, step: Float = 1.0f,
repeat: Int = 1, dType: DType = Base.MX_REAL_TYPE): Symbol = {
def arange(start: Float, stop: Option[Float], step: Float,
repeat: Int, infer_range: Boolean, name: String,
dType: DType): Symbol = {
val params = Map("start" -> start, "step" -> step, "repeat" -> repeat,
"infer_range" -> true, "dtype" -> dType.toString())
"infer_range" -> infer_range, "dtype" -> dType.toString())
val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get)
createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams)
}
Expand Down

0 comments on commit bf27a07

Please sign in to comment.