Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
NDArray Random API
Browse files Browse the repository at this point in the history
  • Loading branch information
mdespriee committed Sep 10, 2018
1 parent cb574c5 commit b1c26b6
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ object NDArray extends NDArrayBase {

val api = NDArrayAPI

val random = NDArrayRandomAPI

private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = {
froms.foreach { from =>
val weakRef = new WeakReference(from)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,13 @@ package org.apache.mxnet
object NDArrayAPI extends NDArrayAPIBase {
// TODO: Implement CustomOp for NDArray
}

@AddNDArrayRandomAPIs(false)
/**
* typesafe NDArray random module: NDArray.random._
* Main code will be generated during compile time through Macros
*/
object NDArrayRandomAPI extends NDArrayRandomAPIBase {

}

Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ object SymbolAPI extends SymbolAPIBase {

@AddSymbolRandomAPIs(false)
/**
* typesafe Symbol API: Symbol.random._
* typesafe Symbol random module: Symbol.random._
* Main code will be generated during compile time through Macros
*/
object SymbolRandomAPI extends SymbolRandomAPIBase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -576,4 +576,14 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(arr.internal.toDoubleArray === Array(2d, 2d))
assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte))
}

test("random module is present") {
val loc = NDArray.ones(1, 2)
val scale = NDArray.ones(1, 2) * 2
val rnd = NDArray.random.sample_normal(mu = loc, sigma = scale, shape = Some(Shape(3, 4)))
val rnd2 = NDArray.random.random_normal(loc = Some(1f), scale = Some(2f),
shape = Some(Shape(3, 4)))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,6 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}

class SymbolSuite extends FunSuite with BeforeAndAfterAll {

test("random module - normal") {
val loc = Symbol.Variable("loc")
val scale = Symbol.Variable("scale")
val rnd = Symbol.random.sample_normal(mu = Some(loc), sigma = Some(scale),
shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random.random_normal(loc = Some(1f), scale = Some(2f),
shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.normal Symbol args debug info: ${rnd.debugStr}")
println(s"Symbol.random.normal Symbol args debug info: ${rnd2.debugStr}")
// scalastyle:on println
}

test("symbol compose") {
val data = Symbol.Variable("data")

Expand Down Expand Up @@ -85,4 +72,17 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
val data2 = data.clone()
assert(data.toJson === data2.toJson)
}

test("random module is present") {
val loc = Symbol.Variable("loc")
val scale = Symbol.Variable("scale")
val rnd = Symbol.random.sample_normal(mu = Some(loc), sigma = Some(scale),
shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random.random_normal(loc = Some(1f), scale = Some(2f),
shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}")
println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}")
// scalastyle:on println
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ private[mxnet] object APIDocGenerator{
hashCollector += absClassGen(FILE_PATH, true)
hashCollector += absClassGen(FILE_PATH, false)
hashCollector += absRndClassGen(FILE_PATH, true)
// hashCollector += absClassGen(FILE_PATH, false) // TODO random NDArray
hashCollector += absRndClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
val finalHash = hashCollector.mkString("\n")
Expand All @@ -58,7 +58,8 @@ private[mxnet] object APIDocGenerator{
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filter(f => f.name.startsWith("sample") || f.name.startsWith("random"))

val absFuncs = absClassFunctions.filter(f => f.name.startsWith("sample_") || f.name.startsWith("random_"))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateAPISignature(absClassFunction, isSymbol)
Expand All @@ -85,8 +86,8 @@ private[mxnet] object APIDocGenerator{
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions
.filterNot(_.name.startsWith("_"))
.filterNot(_.name.startsWith("sample"))
.filterNot(_.name.startsWith("random"))
.filterNot(_.name.startsWith("random_"))
.filterNot(_.name.startsWith("sample_"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation
private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs
}

private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends StaticAnnotation {
private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeRandomAPIDefs
}

private[mxnet] object NDArrayMacro {
case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
Expand All @@ -44,6 +48,9 @@ private[mxnet] object NDArrayMacro {
def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
typeSafeAPIImpl(c)(annottees: _*)
}
def typeSafeRandomAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
typeSafeRandomAPIImpl(c)(annottees: _*)
}
// scalastyle:off havetype

private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
Expand Down Expand Up @@ -79,6 +86,22 @@ private[mxnet] object NDArrayMacro {
structGeneration(c)(functionDefs, annottees : _*)
}

/**
* Implementation for Dynamic typed API NDArray.random.<functioname>
*/
private def typeSafeRandomAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._

val rndFunctions =
ndarrayFunctions.filter(f => f.name.startsWith("sample_") || f.name.startsWith("random_"))

val functionDefs = rndFunctions.map(f => buildTypeSafeFunction(c)(f))
structGeneration(c)(functionDefs, annottees: _*)
}

/**
* Implementation for Dynamic typed API NDArray.api.<functioname>
*/
private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
import c.universe._

Expand All @@ -91,69 +114,74 @@ private[mxnet] object NDArrayMacro {
val newNDArrayFunctions = {
if (isContrib) ndarrayFunctions.filter(
func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
else ndarrayFunctions.filterNot(_.name.startsWith("_"))
else ndarrayFunctions.filterNot(f => f.name.startsWith("_") &&
!f.name.startsWith("sample_") && !f.name.startsWith("random_"))
}.filterNot(ele => notGenerated.contains(ele.name))

val functionDefs = newNDArrayFunctions.map { ndarrayfunction =>

// Construct argument field
var argDef = ListBuffer[String]()
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
val currArgName = ndarrayarg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case default => ndarrayarg.argName
}
if (ndarrayarg.isOptional) {
argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
}
else {
argDef += s"${currArgName} : ${ndarrayarg.argType}"
}
// NDArray arg implementation
val returnType = "org.apache.mxnet.NDArray"

// TODO: Currently we do not add place holder for NDArray
// Example: an NDArray operator like the following format
// nd.foo(arg1: NDArray(required), arg2: NDArray(Optional), arg3: NDArray(Optional)
// If we place nd.foo(arg1, arg3 = arg3), do we need to add place holder for arg2?
// What it should be?
val base =
if (ndarrayarg.argType.equals(returnType)) {
s"args += $currArgName"
} else if (ndarrayarg.argType.equals(s"Array[$returnType]")){
s"args ++= $currArgName"
} else {
"map(\"" + ndarrayarg.argName + "\") = " + currArgName
}
impl.append(
if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get"
else base
)
})
// add default out parameter
argDef += "out : Option[NDArray] = None"
impl += "if (!out.isEmpty) map(\"out\") = out.get"
// scalastyle:off
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.NDArrayFuncReturn"
var finalStr = s"def ${ndarrayfunction.name}"
finalStr += s" (${argDef.mkString(",")}) : $returnType"
finalStr += s" = {${impl.mkString("\n")}}"
c.parse(finalStr).asInstanceOf[DefDef]
}
val functionDefs = newNDArrayFunctions.map(f => buildTypeSafeFunction(c)(f))

structGeneration(c)(functionDefs, annottees : _*)
}

private def buildTypeSafeFunction(c: blackbox.Context)
(ndarrayfunction: NDArrayFunction): c.universe.DefDef = {
import c.universe._

// Construct argument field
var argDef = ListBuffer[String]()
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
val currArgName = ndarrayarg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case default => ndarrayarg.argName
}
if (ndarrayarg.isOptional) {
argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
}
else {
argDef += s"${currArgName} : ${ndarrayarg.argType}"
}
// NDArray arg implementation
val returnType = "org.apache.mxnet.NDArray"

// TODO: Currently we do not add place holder for NDArray
// Example: an NDArray operator like the following format
// nd.foo(arg1: NDArray(required), arg2: NDArray(Optional), arg3: NDArray(Optional)
// If we place nd.foo(arg1, arg3 = arg3), do we need to add place holder for arg2?
// What it should be?
val base =
if (ndarrayarg.argType.equals(returnType)) {
s"args += $currArgName"
} else if (ndarrayarg.argType.equals(s"Array[$returnType]")) {
s"args ++= $currArgName"
} else {
"map(\"" + ndarrayarg.argName + "\") = " + currArgName
}
impl.append(
if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get"
else base
)
})
// add default out parameter
argDef += "out : Option[NDArray] = None"
impl += "if (!out.isEmpty) map(\"out\") = out.get"
// scalastyle:off
impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.NDArrayFuncReturn"
var finalStr = s"def ${ndarrayfunction.name}"
finalStr += s" (${argDef.mkString(",")}) : $returnType"
finalStr += s" = {${impl.mkString("\n")}}"
c.parse(finalStr).asInstanceOf[DefDef]
}

private def structGeneration(c: blackbox.Context)
(funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*)
: c.Expr[Any] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ private[mxnet] object SymbolImplMacros {
import c.universe._

val rndFunctions =
symbolFunctions.filter(f => f.name.startsWith("sample") || f.name.startsWith("random"))
symbolFunctions.filter(f => f.name.startsWith("sample_") || f.name.startsWith("random_"))

val functionDefs = rndFunctions.map(f => buildTypedFunction(c)(f))
structGeneration(c)(functionDefs, annottees: _*)
Expand All @@ -118,7 +118,7 @@ private[mxnet] object SymbolImplMacros {
if (isContrib) symbolFunctions.filter(
func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
else symbolFunctions.filter(f => !f.name.startsWith("_") &&
!f.name.startsWith("sample") && !f.name.startsWith("random"))
!f.name.startsWith("sample_") && !f.name.startsWith("random_"))
}.filterNot(ele => notGenerated.contains(ele.name))

val functionDefs = newSymbolFunctions.map(f => buildTypedFunction(c)(f))
Expand Down

0 comments on commit b1c26b6

Please sign in to comment.