Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
strip leading underscore
Browse files Browse the repository at this point in the history
  • Loading branch information
mdespriee committed Sep 16, 2018
1 parent 7f845ef commit 0b75b87
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,8 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
test("random module is present") {
val loc = NDArray.ones(1, 2)
val scale = NDArray.ones(1, 2) * 2
val rnd = NDArray.random._sample_normal(mu = loc, sigma = scale, shape = Some(Shape(3, 4)))
val rnd2 = NDArray.random._random_normal(loc = Some(1f), scale = Some(2f),
val rnd = NDArray.random.sample_normal(mu = loc, sigma = scale, shape = Some(Shape(3, 4)))
val rnd2 = NDArray.random.random_normal(loc = Some(1f), scale = Some(2f),
shape = Some(Shape(3, 4)))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
test("random module is present") {
val loc = Symbol.Variable("loc")
val scale = Symbol.Variable("scale")
val rnd = Symbol.random._sample_normal(mu = Some(loc), sigma = Some(scale),
val rnd = Symbol.random.sample_normal(mu = Some(loc), sigma = Some(scale),
shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random._random_normal(loc = Some(1f), scale = Some(2f),
val rnd2 = Symbol.random.random_normal(loc = Some(1f), scale = Some(2f),
shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ private[mxnet] object APIDocGenerator{
def absRndClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
typeSafeClassGen(
getSymbolNDArrayMethods(isSymbol)
.filter(f => f.name.startsWith("_random") || f.name.startsWith("_sample")),
.filter(f => f.name.startsWith("_random") || f.name.startsWith("_sample"))
.map(f => f.copy(name = f.name.stripPrefix("_"))),
FILE_PATH,
if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
isSymbol
Expand Down Expand Up @@ -224,8 +225,8 @@ private[mxnet] object APIDocGenerator{
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
new absClassFunction(aliasName, desc.value, argList.toList, returnType)
absClassFunction(aliasName, desc.value, argList.toList, returnType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ private[mxnet] object NDArrayMacro {
private def typeSafeRandomAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._

val rndFunctions =
ndarrayFunctions.filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_"))
val rndFunctions = ndarrayFunctions
.filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_"))
.map(f => f.copy(name = f.name.stripPrefix("_")))

val functionDefs = rndFunctions.map(f => buildTypeSafeFunction(c)(f))
structGeneration(c)(functionDefs, annottees: _*)
Expand Down Expand Up @@ -256,8 +257,8 @@ private[mxnet] object NDArrayMacro {
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.NDArray")
new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
}
new NDArrayFunction(aliasName, argList.toList)
NDArrayFunction(aliasName, argList.toList)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ private[mxnet] object SymbolImplMacros {
val newSymbolFunctions = {
if (isContrib) symbolFunctions.filter(
func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
else symbolFunctions.filter(!_.name.startsWith("_"))
else symbolFunctions.filterNot(_.name.startsWith("_"))
}

val functionDefs = newSymbolFunctions map { symbolfunction =>
Expand All @@ -92,8 +92,9 @@ private[mxnet] object SymbolImplMacros {
private def typedRandomAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._

val rndFunctions =
symbolFunctions.filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_"))
val rndFunctions = symbolFunctions
.filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_"))
.map(f => f.copy(name = f.name.stripPrefix("_")))

val functionDefs = rndFunctions.map(f => buildTypedFunction(c)(f))
structGeneration(c)(functionDefs, annottees: _*)
Expand All @@ -117,7 +118,7 @@ private[mxnet] object SymbolImplMacros {
val newSymbolFunctions = {
if (isContrib) symbolFunctions.filter(
func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
else symbolFunctions.filter(f => !f.name.startsWith("_"))
else symbolFunctions.filterNot(_.name.startsWith("_"))
}.filterNot(ele => notGenerated.contains(ele.name))

val functionDefs = newSymbolFunctions.map(f => buildTypedFunction(c)(f))
Expand Down Expand Up @@ -259,8 +260,8 @@ private[mxnet] object SymbolImplMacros {
val argList = argNames zip argTypes map { case (argName, argType) =>
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.Symbol")
new SymbolArg(argName, typeAndOption._1, typeAndOption._2)
SymbolArg(argName, typeAndOption._1, typeAndOption._2)
}
new SymbolFunction(aliasName, argList.toList)
SymbolFunction(aliasName, argList.toList)
}
}

0 comments on commit 0b75b87

Please sign in to comment.