diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala index b2a53fd9f2dd..bb9518d51f1e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala @@ -153,3 +153,21 @@ private[mxnet] object Base { } class MXNetError(val err: String) extends Exception(err) + +// Some type-classes to ease the work in Symbol.random and NDArray.random modules + +class SymbolOrScalar[T](val isScalar: Boolean) +object SymbolOrScalar { + def apply[T](implicit ev: SymbolOrScalar[T]): SymbolOrScalar[T] = ev + implicit object FloatWitness extends SymbolOrScalar[Float](true) + implicit object IntWitness extends SymbolOrScalar[Int](true) + implicit object SymbolWitness extends SymbolOrScalar[Symbol](false) +} + +class NDArrayOrScalar[T](val isScalar: Boolean) +object NDArrayOrScalar { + def apply[T](implicit ev: NDArrayOrScalar[T]): NDArrayOrScalar[T] = ev + implicit object FloatWitness extends NDArrayOrScalar[Float](true) + implicit object IntWitness extends NDArrayOrScalar[Int](true) + implicit object NDArrayWitness extends NDArrayOrScalar[NDArray](false) +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 3a0c3c11f16a..125958150b72 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -40,6 +40,7 @@ object NDArray extends NDArrayBase { private val functions: Map[String, NDArrayFunction] = initNDArrayModule() val api = NDArrayAPI + val random = NDArrayRandomAPI private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = { froms.foreach { from => diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala index 1d8551c1b1e5..024fed1c4ba6 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala @@ -15,11 +15,22 @@ * limitations under the License. */ package org.apache.mxnet -@AddNDArrayAPIs(false) + /** * typesafe NDArray API: NDArray.api._ * Main code will be generated during compile time through Macros */ +@AddNDArrayAPIs(false) object NDArrayAPI extends NDArrayAPIBase { // TODO: Implement CustomOp for NDArray } + +/** + * typesafe NDArray random module: NDArray.random._ + * Main code will be generated during compile time through Macros + */ +@AddNDArrayRandomAPIs(false) +object NDArrayRandomAPI extends NDArrayRandomAPIBase { + +} + diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 01349a689b6c..29885fc723cd 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -842,6 +842,7 @@ object Symbol extends SymbolBase { private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) val api = SymbolAPI + val random = SymbolRandomAPI def pow(sym1: Symbol, sym2: Symbol): Symbol = { Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2)) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala index 1bfb0559cf96..f166de11ea52 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala @@ -19,11 +19,11 @@ package org.apache.mxnet import scala.collection.mutable -@AddSymbolAPIs(false) /** * typesafe Symbol API: Symbol.api._ * Main code will be generated during compile time through Macros */ +@AddSymbolAPIs(false) object SymbolAPI extends SymbolAPIBase { def Custom (op_type : String, kwargs : mutable.Map[String, Any], name : String = null, attr : Map[String, String] = null) : Symbol = { @@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase { Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap) } } + +/** + * typesafe Symbol random module: Symbol.random._ + * Main code will be generated during compile time through Macros + */ +@AddSymbolRandomAPIs(false) +object SymbolRandomAPI extends SymbolRandomAPIBase { + +} + diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 5d88bb39e502..7992a0ed867b 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -576,4 +576,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(arr.internal.toDoubleArray === Array(2d, 2d)) assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte)) } + + test("NDArray random module is generated properly") { + val lam = NDArray.ones(1, 2) + val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4))) + val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4))) + assert(rnd.shape === Shape(1, 2, 3, 4)) + assert(rnd2.shape === Shape(3, 4)) + } + + test("NDArray random module is generated properly - special case of 'normal'") { + val mu = NDArray.ones(1, 2) + val sigma = NDArray.ones(1, 2) * 2 + val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4))) + val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4))) + assert(rnd.shape === Shape(1, 2, 3, 4)) + assert(rnd2.shape === Shape(3, 4)) + } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala index ebb61d7d4bfb..d134c83ff7e7 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala @@ -20,6 +20,7 @@ package org.apache.mxnet import org.scalatest.{BeforeAndAfterAll, FunSuite} class SymbolSuite extends FunSuite with BeforeAndAfterAll { + test("symbol compose") { val data = Symbol.Variable("data") @@ -71,4 +72,25 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll { val data2 = data.clone() assert(data.toJson === data2.toJson) } + + test("Symbol random module is generated properly") { + val lam = Symbol.Variable("lam") + val rnd = Symbol.random.poisson(lam = Some(lam), shape = Some(Shape(2, 2))) + val rnd2 = Symbol.random.poisson(lam = Some(1f), shape = Some(Shape(2, 2))) + // scalastyle:off println + println(s"Symbol.random.poisson debug info: ${rnd.debugStr}") + println(s"Symbol.random.poisson debug info: ${rnd2.debugStr}") + // scalastyle:on println + } + + test("Symbol random module is generated properly - special case of 'normal'") { + val loc = Symbol.Variable("loc") + val scale = Symbol.Variable("scale") + val rnd = Symbol.random.normal(mu = Some(loc), sigma = Some(scale), shape = Some(Shape(2, 2))) + val rnd2 = Symbol.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(2, 2))) + // scalastyle:off println + println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}") + println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}") + // scalastyle:on println + } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index ce12dc7cd5a0..97cd18a5b337 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -27,13 +27,15 @@ import scala.collection.mutable.ListBuffer * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala * The code will be executed during Macros stage and file live in Core stage */ -private[mxnet] object APIDocGenerator extends GeneratorBase { +private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { def main(args: Array[String]): Unit = { val FILE_PATH = args(0) val hashCollector = ListBuffer[String]() hashCollector += typeSafeClassGen(FILE_PATH, true) hashCollector += typeSafeClassGen(FILE_PATH, false) + hashCollector += typeSafeRandomClassGen(FILE_PATH, true) + hashCollector += typeSafeRandomClassGen(FILE_PATH, false) hashCollector += nonTypeSafeClassGen(FILE_PATH, true) hashCollector += nonTypeSafeClassGen(FILE_PATH, false) hashCollector += javaClassGen(FILE_PATH) @@ -57,8 +59,27 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { writeFile( FILE_PATH, + "package org.apache.mxnet", if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase", + "import org.apache.mxnet.annotation.Experimental", + generated) + } + + def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = { + val generated = typeSafeRandomFunctionsToGenerate(isSymbol) + .map { func => + val scalaDoc = generateAPIDocFromBackend(func) + val typeParameter = randomGenericTypeSpec(isSymbol, false) + val decl = generateAPISignature(func, isSymbol, typeParameter) + s"$scalaDoc\n$decl" + } + + writeFile( + FILE_PATH, "package org.apache.mxnet", + if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase", + """import org.apache.mxnet.annotation.Experimental + |import scala.reflect.ClassTag""".stripMargin, generated) } @@ -85,8 +106,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { writeFile( FILE_PATH, - if (isSymbol) "SymbolBase" else "NDArrayBase", "package org.apache.mxnet", + if (isSymbol) "SymbolBase" else "NDArrayBase", + "import org.apache.mxnet.annotation.Experimental", absFuncs) } @@ -110,7 +132,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { }).toSeq val packageName = "NDArrayBase" val packageDef = "package org.apache.mxnet.javaapi" - writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs) + writeFile( + filePath + "javaapi/", + packageDef, + packageName, + "import org.apache.mxnet.annotation.Experimental", + absFuncs) } def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = { @@ -146,7 +173,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { } } - def generateAPISignature(func: Func, isSymbol: Boolean): String = { + def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = { val argDef = ListBuffer[String]() argDef ++= typedFunctionCommonArgDef(func) @@ -162,7 +189,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { val returnType = func.returnType s"""@Experimental - |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin + |def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin } def generateJavaAPISignature(func : Func) : String = { @@ -223,8 +250,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { } } - def writeFile(FILE_PATH: String, className: String, packageDef: String, - absFuncs: Seq[String]): String = { + def writeFile(FILE_PATH: String, packageDef: String, className: String, + imports: String, absFuncs: Seq[String]): String = { val finalStr = s"""/* @@ -246,7 +273,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { | |$packageDef | - |import org.apache.mxnet.annotation.Experimental + |$imports | |// scalastyle:off |abstract class $className { diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala index 9245ef1b437f..1c2c4fd704b3 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -23,7 +23,7 @@ import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} import scala.collection.mutable.ListBuffer import scala.reflect.macros.blackbox -abstract class GeneratorBase { +private[mxnet] abstract class GeneratorBase { type Handle = Long case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) { @@ -46,7 +46,8 @@ abstract class GeneratorBase { } } - def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = { + // filter the operators to generate in the type-safe Symbol.api and NDArray.api + protected def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = { // Operators that should not be generated val notGenerated = Set("Custom") @@ -144,8 +145,8 @@ abstract class GeneratorBase { result } + // build function argument definition, with optionality, and safe names protected def typedFunctionCommonArgDef(func: Func): List[String] = { - // build function argument definition, with optionality, and safe names func.listOfArgs.map(arg => if (arg.isOptional) { // let's avoid a stupid Option[Array[...]] @@ -161,3 +162,71 @@ abstract class GeneratorBase { ) } } + +// a mixin to ease generating the Random module +private[mxnet] trait RandomHelpers { + self: GeneratorBase => + + // a generic type spec used in Symbol.random and NDArray.random modules + protected def randomGenericTypeSpec(isSymbol: Boolean, fullPackageSpec: Boolean): String = { + val classTag = if (fullPackageSpec) "scala.reflect.ClassTag" else "ClassTag" + if (isSymbol) s"[T: SymbolOrScalar : $classTag]" + else s"[T: NDArrayOrScalar : $classTag]" + } + + // filter the operators to generate in the type-safe Symbol.random and NDArray.random + protected def typeSafeRandomFunctionsToGenerate(isSymbol: Boolean): List[Func] = { + getBackEndFunctions(isSymbol) + .filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_")) + .map(f => f.copy(name = f.name.stripPrefix("_"))) + // unify _random and _sample + .map(f => unifyRandom(f, isSymbol)) + // deduplicate + .groupBy(_.name) + .mapValues(_.head) + .values + .toList + } + + // unify call targets (random_xyz and sample_xyz) and unify their argument types + private def unifyRandom(func: Func, isSymbol: Boolean): Func = { + var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol", + "java.lang.Float", "java.lang.Integer") + + func.copy( + name = func.name.replaceAll("(random|sample)_", ""), + listOfArgs = func.listOfArgs + .map(hackNormalFunc) + .map(arg => + if (typeConv(arg.argType)) arg.copy(argType = "T") + else arg + ) + // TODO: some functions are non consistent in random_ vs sample_ regarding optionality + // we may try to unify that as well here. + ) + } + + // hacks to manage the fact that random_normal and sample_normal have + // non-consistent parameter naming in the back-end + // this first one, merge loc/scale and mu/sigma + protected def hackNormalFunc(arg: Arg): Arg = { + if (arg.argName == "loc") arg.copy(argName = "mu") + else if (arg.argName == "scale") arg.copy(argName = "sigma") + else arg + } + + // this second one reverts this merge prior to back-end call + protected def unhackNormalFunc(func: Func): String = { + if (func.name.equals("normal")) { + s"""if(target.equals("random_normal")) { + | if(map.contains("mu")) { map("loc") = map("mu"); map.remove("mu") } + | if(map.contains("sigma")) { map("scale") = map("sigma"); map.remove("sigma") } + |} + """.stripMargin + } else { + "" + } + + } + +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index d85abe1ecc4f..c18694b59bf6 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -18,7 +18,6 @@ package org.apache.mxnet import scala.annotation.StaticAnnotation -import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox @@ -30,6 +29,14 @@ private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation private[mxnet] def macroTransform(annottees: Any*) = macro TypedNDArrayAPIMacro.typeSafeAPIDefs } +private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = + macro TypedNDArrayRandomAPIMacro.typeSafeAPIDefs +} + +/** + * For non-typed NDArray API + */ private[mxnet] object NDArrayMacro extends GeneratorBase { def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { @@ -70,6 +77,9 @@ private[mxnet] object NDArrayMacro extends GeneratorBase { } } +/** + * NDArray.api code generation + */ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { @@ -78,9 +88,9 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) } - val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib) + val functionDefs = typeSafeFunctionsToGenerate(isSymbol = false, isContrib) + .map(f => buildTypedFunction(c)(f)) - val functionDefs = functions.map(f => buildTypedFunction(c)(f)) structGeneration(c)(functionDefs, annottees: _*) } @@ -89,49 +99,136 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { import c.universe._ val returnType = "org.apache.mxnet.NDArrayFuncReturn" - val ndarrayType = "org.apache.mxnet.NDArray" - - // Construct argument field - val argDef = ListBuffer[String]() - argDef ++= typedFunctionCommonArgDef(function) - argDef += "out : Option[NDArray] = None" - - // Construct Implementation field - var impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - impl += s"val args = scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]" - - // NDArray arg implementation - impl ++= function.listOfArgs.map { arg => - if (arg.argType.equals(s"Array[$ndarrayType]")) { - s"args ++= ${arg.safeArgName}" - } else { - val base = - if (arg.argType.equals(ndarrayType)) { - // ndarrays go to args + + // Construct API arguments declaration + val argDecl = super.typedFunctionCommonArgDef(function) :+ "out : Option[NDArray] = None" + + // Map API input args to backend args + val backendArgsMapping = + function.listOfArgs.map { arg => + // ndarrays go to args, other types go to kwargs + if (arg.argType.equals(s"Array[org.apache.mxnet.NDArray]")) { + s"args ++= ${arg.safeArgName}.toSeq" + } else { + val base = if (arg.argType.equals("org.apache.mxnet.NDArray")) { s"args += ${arg.safeArgName}" } else { - // other types go to kwargs s"""map("${arg.argName}") = ${arg.safeArgName}""" } - if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get" - else base + if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get" + else base + } } - } - impl += - s"""if (!out.isEmpty) map("out") = out.get - |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke( - | "${function.name}", args.toSeq, map.toMap) + val impl = + s""" + |def ${function.name} + | (${argDecl.mkString(",")}): $returnType = { + | + | val map = scala.collection.mutable.Map[String, Any]() + | val args = scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray] + | + | if (!out.isEmpty) map("out") = out.get + | + | ${backendArgsMapping.mkString("\n")} + | + | org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke( + | "${function.name}", args.toSeq, map.toMap) + |} """.stripMargin - // Combine and build the function string - val finalStr = - s"""def ${function.name} - | (${argDef.mkString(",")}) : $returnType - | = {${impl.mkString("\n")}} + c.parse(impl).asInstanceOf[DefDef] + } +} + + +/** + * NDArray.random code generation + */ +private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase + with RandomHelpers { + + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { + // Note: no contrib managed in this module + + val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = false) + .map(f => buildTypedFunction(c)(f)) + + structGeneration(c)(functionDefs, annottees: _*) + } + + protected def buildTypedFunction(c: blackbox.Context) + (function: Func): c.universe.DefDef = { + import c.universe._ + + val returnType = "org.apache.mxnet.NDArrayFuncReturn" + + // Construct API arguments declaration + val argDecl = super.typedFunctionCommonArgDef(function) :+ "out : Option[NDArray] = None" + + // Map API input args to backend args + val backendArgsMapping = + function.listOfArgs.map { arg => + // ndarrays go to args, other types go to kwargs + if (arg.argType.equals("Array[org.apache.mxnet.NDArray]")) { + s"args ++= ${arg.safeArgName}.toSeq" + } else { + if (arg.argType.equals("T")) { + if (arg.isOptional) { + s"""if(${arg.safeArgName}.isDefined) { + | if(isScalar) { + | map("${arg.argName}") = ${arg.safeArgName}.get + | } else { + | args += ${arg.safeArgName}.get.asInstanceOf[org.apache.mxnet.NDArray] + | } + |} + """.stripMargin + } else { + s"""if(isScalar) { + | map("${arg.argName}") = ${arg.safeArgName} + |} else { + | args += ${arg.safeArgName}.asInstanceOf[org.apache.mxnet.NDArray] + |} + """.stripMargin + } + } else { + if (arg.isOptional) { + s"""if (${arg.safeArgName}.isDefined) map("${arg.argName}")=${arg.safeArgName}.get""" + } else { + s"""map("${arg.argName}") = ${arg.safeArgName}""" + } + } + } + } + + val impl = + s""" + |def ${function.name}${randomGenericTypeSpec(false, true)} + | (${argDecl.mkString(",")}): $returnType = { + | + | val map = scala.collection.mutable.Map[String, Any]() + | val args = scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray] + | val isScalar = NDArrayOrScalar[T].isScalar + | + | if(out.isDefined) map("out") = out.get + | + | ${backendArgsMapping.mkString("\n")} + | + | val target = if(isScalar) { + | "random_${function.name}" + | } else { + | "sample_${function.name}" + | } + | + | ${unhackNormalFunc(function)} + | + | org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke( + | target, args.toSeq, map.toMap) + |} """.stripMargin - c.parse(finalStr).asInstanceOf[DefDef] + c.parse(impl).asInstanceOf[DefDef] } + + } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index ab864e1ef195..7ec80b9c066c 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -17,8 +17,8 @@ package org.apache.mxnet + import scala.annotation.StaticAnnotation -import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox @@ -30,6 +30,14 @@ private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation private[mxnet] def macroTransform(annottees: Any*) = macro TypedSymbolAPIMacro.typeSafeAPIDefs } +private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = + macro TypedSymbolRandomAPIMacro.typeSafeAPIDefs +} + +/** + * For non-typed Symbol API + */ private[mxnet] object SymbolMacro extends GeneratorBase { def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { @@ -63,6 +71,9 @@ private[mxnet] object SymbolMacro extends GeneratorBase { } } +/** + * Symbol.api code generation + */ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { @@ -71,9 +82,9 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b)) } - val functions = typeSafeFunctionsToGenerate(isSymbol = true, isContrib) + val functionDefs = typeSafeFunctionsToGenerate(isSymbol = true, isContrib) + .map(f => buildTypedFunction(c)(f)) - val functionDefs = functions.map(f => buildTypedFunction(c)(f)) structGeneration(c)(functionDefs, annottees: _*) } @@ -82,45 +93,111 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { import c.universe._ val returnType = "org.apache.mxnet.Symbol" - val symbolType = "org.apache.mxnet.Symbol" - - // Construct argument field - val argDef = ListBuffer[String]() - argDef ++= typedFunctionCommonArgDef(function) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" - - // Construct Implementation field - val impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - impl += s"var args = scala.collection.Seq[$symbolType]()" - - // Symbol arg implementation - impl ++= function.listOfArgs.map { arg => - if (arg.argType.equals(s"Array[$symbolType]")) { - s"if (!${arg.safeArgName}.isEmpty) args = ${arg.safeArgName}.toSeq" - } else { - // all go in kwargs - if (arg.isOptional) { - s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = ${arg.safeArgName}.get""" + + // Construct API arguments declaration + val argDecl = super.typedFunctionCommonArgDef(function) :+ + "name : String = null" :+ + "attr : Map[String, String] = null" + + // Map API input args to backend args + val backendArgsMapping = + function.listOfArgs.map { arg => + if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) { + s"args = ${arg.safeArgName}.toSeq" } else { - s"""map("${arg.argName}") = ${arg.safeArgName}""" + // all go in kwargs + if (arg.isOptional) { + s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = ${arg.safeArgName}.get""" + } else { + s"""map("${arg.argName}") = ${arg.safeArgName}""" + } } } - } - impl += - s"""org.apache.mxnet.Symbol.createSymbolGeneral( - | "${function.name}", name, attr, args, map.toMap) + val impl = + s""" + |def ${function.name} + | (${argDecl.mkString(",")}): $returnType = { + | + | val map = scala.collection.mutable.Map[String, Any]() + | var args = scala.collection.Seq[org.apache.mxnet.Symbol]() + | + | ${backendArgsMapping.mkString("\n")} + | + | org.apache.mxnet.Symbol.createSymbolGeneral( + | "${function.name}", name, attr, args, map.toMap) + |} """.stripMargin - // Combine and build the function string - val finalStr = - s"""def ${function.name} - | (${argDef.mkString(",")}) : $returnType - | = {${impl.mkString("\n")}} + c.parse(impl).asInstanceOf[DefDef] + } +} + + +/** + * Symbol.random code generation + */ +private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase + with RandomHelpers { + + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { + val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = true) + .map(f => buildTypedFunction(c)(f)) + + structGeneration(c)(functionDefs, annottees: _*) + } + + protected def buildTypedFunction(c: blackbox.Context) + (function: Func): c.universe.DefDef = { + import c.universe._ + + val returnType = "org.apache.mxnet.Symbol" + + // Construct API arguments declaration + val argDecl = super.typedFunctionCommonArgDef(function) :+ + "name : String = null" :+ + "attr : Map[String, String] = null" + + // Map API input args to backend args + val backendArgsMapping = + function.listOfArgs.map { arg => + if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) { + s"args = ${arg.safeArgName}.toSeq" + } else { + // all go in kwargs + if (arg.isOptional) { + s"""if (${arg.safeArgName}.isDefined) map("${arg.argName}") = ${arg.safeArgName}.get""" + } else { + s"""map("${arg.argName}") = ${arg.safeArgName}""" + } + } + } + + val impl = + s""" + |def ${function.name}${randomGenericTypeSpec(true, true)} + | (${argDecl.mkString(",")}): $returnType = { + | + | val map = scala.collection.mutable.Map[String, Any]() + | var args = scala.collection.Seq[org.apache.mxnet.Symbol]() + | val isScalar = SymbolOrScalar[T].isScalar + | + | ${backendArgsMapping.mkString("\n")} + | + | val target = if(isScalar) { + | "random_${function.name}" + | } else { + | "sample_${function.name}" + | } + | + | ${unhackNormalFunc(function)} + | + | org.apache.mxnet.Symbol.createSymbolGeneral( + | target, name, attr, args, map.toMap) + |} """.stripMargin - c.parse(finalStr).asInstanceOf[DefDef] + c.parse(impl).asInstanceOf[DefDef] } } +