From 704ab6030c9803ac70d2726618885cef323a4420 Mon Sep 17 00:00:00 2001 From: Mathieu DESPRIEE Date: Sun, 28 Oct 2018 20:21:46 +0100 Subject: [PATCH 1/9] introduce random API --- .../main/scala/org/apache/mxnet/Base.scala | 16 ++ .../main/scala/org/apache/mxnet/NDArray.scala | 1 + .../scala/org/apache/mxnet/NDArrayAPI.scala | 13 +- .../main/scala/org/apache/mxnet/Symbol.scala | 1 + .../scala/org/apache/mxnet/SymbolAPI.scala | 12 +- .../scala/org/apache/mxnet/NDArraySuite.scala | 17 ++ .../scala/org/apache/mxnet/SymbolSuite.scala | 22 +++ .../org/apache/mxnet/APIDocGenerator.scala | 34 +++- .../org/apache/mxnet/GeneratorBase.scala | 79 +++++++- .../scala/org/apache/mxnet/NDArrayMacro.scala | 175 ++++++++++++++---- .../scala/org/apache/mxnet/SymbolMacro.scala | 152 +++++++++++---- 11 files changed, 433 insertions(+), 89 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala index b2a53fd9f2dd..9a28e5a5927b 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala @@ -153,3 +153,19 @@ private[mxnet] object Base { } class MXNetError(val err: String) extends Exception(err) + +// used in the API Symbol.random for functions accepting multiple input types +class SymbolOrValue[T] +object SymbolOrValue { + implicit object FSymbolWitness extends SymbolOrValue[Float] + implicit object ISymbolWitness extends SymbolOrValue[Int] + implicit object SymbolWitness extends SymbolOrValue[Symbol] +} + +// used in the API NDArray.random for functions accepting multiple input types +class NDArrayOrValue[T] +object NDArrayOrValue { + implicit object FArrayWitness extends NDArrayOrValue[Float] + implicit object IArrayWitness extends NDArrayOrValue[Int] + implicit object ArrayWitness extends NDArrayOrValue[NDArray] +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index f9f2dbe42a90..90f2073dd569 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -39,6 +39,7 @@ object NDArray extends NDArrayBase { private val functions: Map[String, NDArrayFunction] = initNDArrayModule() val api = NDArrayAPI + val random = NDArrayRandomAPI private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = { froms.foreach { from => diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala index 1d8551c1b1e5..024fed1c4ba6 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala @@ -15,11 +15,22 @@ * limitations under the License. */ package org.apache.mxnet -@AddNDArrayAPIs(false) + /** * typesafe NDArray API: NDArray.api._ * Main code will be generated during compile time through Macros */ +@AddNDArrayAPIs(false) object NDArrayAPI extends NDArrayAPIBase { // TODO: Implement CustomOp for NDArray } + +/** + * typesafe NDArray random module: NDArray.random._ + * Main code will be generated during compile time through Macros + */ +@AddNDArrayRandomAPIs(false) +object NDArrayRandomAPI extends NDArrayRandomAPIBase { + +} + diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 4472a8426f9f..627b3bb6fe4c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -841,6 +841,7 @@ object Symbol extends SymbolBase { private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) val api = SymbolAPI + val random = SymbolRandomAPI def pow(sym1: Symbol, sym2: Symbol): Symbol = { Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2)) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala index 1bfb0559cf96..f166de11ea52 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala @@ -19,11 +19,11 @@ package org.apache.mxnet import scala.collection.mutable -@AddSymbolAPIs(false) /** * typesafe Symbol API: Symbol.api._ * Main code will be generated during compile time through Macros */ +@AddSymbolAPIs(false) object SymbolAPI extends SymbolAPIBase { def Custom (op_type : String, kwargs : mutable.Map[String, Any], name : String = null, attr : Map[String, String] = null) : Symbol = { @@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase { Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap) } } + +/** + * typesafe Symbol random module: Symbol.random._ + * Main code will be generated during compile time through Macros + */ +@AddSymbolRandomAPIs(false) +object SymbolRandomAPI extends SymbolRandomAPIBase { + +} + diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 5d88bb39e502..7992a0ed867b 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -576,4 +576,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(arr.internal.toDoubleArray === Array(2d, 2d)) assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte)) } + + test("NDArray random module is generated properly") { + val lam = NDArray.ones(1, 2) + val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4))) + val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4))) + assert(rnd.shape === Shape(1, 2, 3, 4)) + assert(rnd2.shape === Shape(3, 4)) + } + + test("NDArray random module is generated properly - special case of 'normal'") { + val mu = NDArray.ones(1, 2) + val sigma = NDArray.ones(1, 2) * 2 + val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4))) + val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4))) + assert(rnd.shape === Shape(1, 2, 3, 4)) + assert(rnd2.shape === Shape(3, 4)) + } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala index ebb61d7d4bfb..d134c83ff7e7 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala @@ -20,6 +20,7 @@ package org.apache.mxnet import org.scalatest.{BeforeAndAfterAll, FunSuite} class SymbolSuite extends FunSuite with BeforeAndAfterAll { + test("symbol compose") { val data = Symbol.Variable("data") @@ -71,4 +72,25 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll { val data2 = data.clone() assert(data.toJson === data2.toJson) } + + test("Symbol random module is generated properly") { + val lam = Symbol.Variable("lam") + val rnd = Symbol.random.poisson(lam = Some(lam), shape = Some(Shape(2, 2))) + val rnd2 = Symbol.random.poisson(lam = Some(1f), shape = Some(Shape(2, 2))) + // scalastyle:off println + println(s"Symbol.random.poisson debug info: ${rnd.debugStr}") + println(s"Symbol.random.poisson debug info: ${rnd2.debugStr}") + // scalastyle:on println + } + + test("Symbol random module is generated properly - special case of 'normal'") { + val loc = Symbol.Variable("loc") + val scale = Symbol.Variable("scale") + val rnd = Symbol.random.normal(mu = Some(loc), sigma = Some(scale), shape = Some(Shape(2, 2))) + val rnd2 = Symbol.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(2, 2))) + // scalastyle:off println + println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}") + println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}") + // scalastyle:on println + } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index bfa378ea9e95..e022e533a214 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -27,13 +27,15 @@ import scala.collection.mutable.ListBuffer * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala * The code will be executed during Macros stage and file live in Core stage */ -private[mxnet] object APIDocGenerator extends GeneratorBase { +private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { def main(args: Array[String]): Unit = { val FILE_PATH = args(0) val hashCollector = ListBuffer[String]() hashCollector += typeSafeClassGen(FILE_PATH, true) hashCollector += typeSafeClassGen(FILE_PATH, false) + hashCollector += typeSafeRandomClassGen(FILE_PATH, true) + hashCollector += typeSafeRandomClassGen(FILE_PATH, false) hashCollector += nonTypeSafeClassGen(FILE_PATH, true) hashCollector += nonTypeSafeClassGen(FILE_PATH, false) val finalHash = hashCollector.mkString("\n") @@ -61,6 +63,22 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { generated) } + def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = { + val generated = typeSafeRandomFunctionsToGenerate(isSymbol) + .map { func => + val scalaDoc = generateAPIDocFromBackend(func) + val typeParameter = randomGenericTypeSpec(isSymbol) + val decl = generateAPISignature(func, isSymbol, typeParameter) + s"$scalaDoc\n$decl" + } + + writeFile( + FILE_PATH, + if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase", + "package org.apache.mxnet", + generated) + } + def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { val absFuncs = functionsToGenerate(isSymbol, isContrib = false) .map { func => @@ -113,22 +131,22 @@ private[mxnet] object APIDocGenerator extends GeneratorBase { } } - def generateAPISignature(func: Func, isSymbol: Boolean): String = { - val argDef = ListBuffer[String]() + def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = { + val argDecl = ListBuffer[String]() - argDef ++= typedFunctionCommonArgDef(func) + argDecl ++= buildArgDecl(func) if (isSymbol) { - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" + argDecl += "name : String = null" + argDecl += "attr : Map[String, String] = null" } else { - argDef += "out : Option[NDArray] = None" + argDecl += "out : Option[NDArray] = None" } val returnType = func.returnType s"""@Experimental - |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin + |def ${func.name}$typeParameter (${argDecl.mkString(", ")}): $returnType""".stripMargin } def writeFile(FILE_PATH: String, className: String, packageDef: String, diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala index f4c4a91bdf9a..29bc9503863e 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -23,7 +23,7 @@ import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} import scala.collection.mutable.ListBuffer import scala.reflect.macros.blackbox -abstract class GeneratorBase { +private[mxnet] abstract class GeneratorBase { type Handle = Long case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) { @@ -36,7 +36,8 @@ abstract class GeneratorBase { case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String) - def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = { + // filter the operators to generate (in the non type-safe apis) + protected def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = { val l = getBackEndFunctions(isSymbol) if (isContrib) { l.filter(func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) @@ -45,7 +46,8 @@ abstract class GeneratorBase { } } - def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = { + // filter the operators to generate in the type-safe Symbol.api and NDArray.api + protected def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = { // Operators that should not be generated val notGenerated = Set("Custom") @@ -138,8 +140,8 @@ abstract class GeneratorBase { result } - protected def typedFunctionCommonArgDef(func: Func): List[String] = { - // build function argument definition, with optionality, and safe names + // build function argument definition, with optionality, and safe names + protected def buildArgDecl(func: Func): List[String] = { func.listOfArgs.map(arg => if (arg.isOptional) { // let's avoid a stupid Option[Array[...]] @@ -155,3 +157,70 @@ abstract class GeneratorBase { ) } } + +// a mixin to ease generating the Random module +private[mxnet] trait RandomHelpers { + self: GeneratorBase => + + // a generic type spec used in Symbol.random and NDArray.random modules + protected def randomGenericTypeSpec(isSymbol: Boolean): String = { + if (isSymbol) "[T: SymbolOrValue : scala.reflect.runtime.universe.TypeTag]" + else "[T: NDArrayOrValue : scala.reflect.runtime.universe.TypeTag]" + } + + // filter the operators to generate in the type-safe Symbol.random and NDArray.random + protected def typeSafeRandomFunctionsToGenerate(isSymbol: Boolean): List[Func] = { + getBackEndFunctions(isSymbol) + .filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_")) + .map(f => f.copy(name = f.name.stripPrefix("_"))) + // unify _random and _sample + .map(f => unifyRandom(f, isSymbol)) + // deduplicate + .groupBy(_.name) + .mapValues(_.head) + .values + .toList + } + + // unify call targets (random_xyz and sample_xyz) and unify their argument types + private def unifyRandom(func: Func, isSymbol: Boolean): Func = { + var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol", + "org.apache.mxnet.Base.MXFloat", "Int") + + func.copy( + name = func.name.replaceAll("(random|sample)_", ""), + listOfArgs = func.listOfArgs + .map(hackNormalFunc) + .map(arg => + if (typeConv(arg.argType)) arg.copy(argType = "T") + else arg + ) + // TODO: some functions are non consistent in random_ vs sample_ regarding optionality + // we may try to unify that as well here. + ) + } + + // hacks to manage the fact that random_normal and sample_normal have + // non-consistent parameter naming in the back-end + // this first one, merge loc/scale and mu/sigma + protected def hackNormalFunc(arg: Arg): Arg = { + if (arg.argName == "loc") arg.copy(argName = "mu") + else if (arg.argName == "scale") arg.copy(argName = "sigma") + else arg + } + + // this second one reverts this merge prior to back-end call + protected def unhackNormalFunc(func: Func): String = { + if (func.name.equals("normal")) { + s"""if(target.equals("random_normal")) { + | if(map.contains("mu")) { map("loc") = map("mu"); map.remove("mu") } + | if(map.contains("sigma")) { map("scale") = map("sigma"); map.remove("sigma") } + |} + """.stripMargin + } else { + "" + } + + } + +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index d85abe1ecc4f..6e39b87af1d1 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -18,7 +18,6 @@ package org.apache.mxnet import scala.annotation.StaticAnnotation -import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox @@ -30,6 +29,14 @@ private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation private[mxnet] def macroTransform(annottees: Any*) = macro TypedNDArrayAPIMacro.typeSafeAPIDefs } +private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = + macro TypedNDArrayRandomAPIMacro.typeSafeAPIDefs +} + +/** + * For non-typed NDArray API + */ private[mxnet] object NDArrayMacro extends GeneratorBase { def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { @@ -70,6 +77,9 @@ private[mxnet] object NDArrayMacro extends GeneratorBase { } } +/** + * NDArray.api code generation + */ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { @@ -78,9 +88,9 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) } - val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib) + val functionDefs = typeSafeFunctionsToGenerate(isSymbol = false, isContrib) + .map(f => buildTypedFunction(c)(f)) - val functionDefs = functions.map(f => buildTypedFunction(c)(f)) structGeneration(c)(functionDefs, annottees: _*) } @@ -88,50 +98,139 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { (function: Func): c.universe.DefDef = { import c.universe._ - val returnType = "org.apache.mxnet.NDArrayFuncReturn" - val ndarrayType = "org.apache.mxnet.NDArray" - - // Construct argument field - val argDef = ListBuffer[String]() - argDef ++= typedFunctionCommonArgDef(function) - argDef += "out : Option[NDArray] = None" - - // Construct Implementation field - var impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - impl += s"val args = scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]" - - // NDArray arg implementation - impl ++= function.listOfArgs.map { arg => - if (arg.argType.equals(s"Array[$ndarrayType]")) { - s"args ++= ${arg.safeArgName}" - } else { - val base = - if (arg.argType.equals(ndarrayType)) { - // ndarrays go to args + val apiReturnType = "org.apache.mxnet.NDArrayFuncReturn" + + // Construct API arguments declaration + val argDecl = super.buildArgDecl(function) :+ "out : Option[NDArray] = None" + + // Map API input args to backend args + val backendArgsMapping = + function.listOfArgs.map { arg => + // ndarrays go to args, other types go to kwargs + if (arg.argType.equals(s"Array[org.apache.mxnet.NDArray]")) { + s"args ++= ${arg.safeArgName}.toSeq" + } else { + val base = if (arg.argType.equals("org.apache.mxnet.NDArray")) { s"args += ${arg.safeArgName}" } else { - // other types go to kwargs s"""map("${arg.argName}") = ${arg.safeArgName}""" } - if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get" - else base + if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get" + else base + } } - } - impl += - s"""if (!out.isEmpty) map("out") = out.get - |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke( - | "${function.name}", args.toSeq, map.toMap) + val impl = + s""" + |def ${function.name} + | (${argDecl.mkString(",")}): $apiReturnType = { + | + | val map = scala.collection.mutable.Map[String, Any]() + | val args = scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray] + | + | if (!out.isEmpty) map("out") = out.get + | + | ${backendArgsMapping.mkString("\n")} + | + | org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke( + | "${function.name}", args.toSeq, map.toMap) + |} """.stripMargin - // Combine and build the function string - val finalStr = - s"""def ${function.name} - | (${argDef.mkString(",")}) : $returnType - | = {${impl.mkString("\n")}} + c.parse(impl).asInstanceOf[DefDef] + } +} + + +/** + * NDArray.random code generation + */ +private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase + with RandomHelpers { + + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { + // Note: no contrib managed in this module + + val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = false) + .map(f => buildTypedFunction(c)(f)) + + structGeneration(c)(functionDefs, annottees: _*) + } + + protected def buildTypedFunction(c: blackbox.Context) + (function: Func): c.universe.DefDef = { + import c.universe._ + + val apiReturnType = "org.apache.mxnet.NDArrayFuncReturn" + + // Construct API arguments declaration + val argDecl = super.buildArgDecl(function) :+ "out : Option[NDArray] = None" + + // Map API input args to backend args + val backendArgsMapping = + function.listOfArgs.map { arg => + // ndarrays go to args, other types go to kwargs + if (arg.argType.equals("Array[org.apache.mxnet.NDArray]")) { + s"args ++= ${arg.safeArgName}.toSeq" + } else { + if (arg.argType.equals("T")) { + if (arg.isOptional) { + s"""${arg.safeArgName} match { + | case None => + | case Some(a) if typeOf[T] =:= typeOf[org.apache.mxnet.NDArray] => + | args += a.asInstanceOf[org.apache.mxnet.NDArray] + | case Some(b) => map("${arg.argName}") = b + |} + """.stripMargin + } else { + s"""${arg.safeArgName} match { + | case a if typeOf[T] =:= typeOf[org.apache.mxnet.NDArray] => + | args += a.asInstanceOf[org.apache.mxnet.NDArray] + | case b => map("${arg.argName}") = b + |} + """.stripMargin + } + } else { + if (arg.isOptional) { + s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = ${arg.safeArgName}.get""" + } else { + s"""map("${arg.argName}") = ${arg.safeArgName}""" + } + } + } + } + + // since the API is mixing calls using NDArrays or Float through template (see unifyRandom), + // to determine the target call, we pick the first arg that is using the template type + val firstArg = function.listOfArgs.filter(arg => arg.argType == "T").head + + val impl = + s""" + |def ${function.name}${randomGenericTypeSpec(false)} + | (${argDecl.mkString(",")}): $apiReturnType = { + | + | import scala.reflect.runtime.universe.typeOf + | val map = scala.collection.mutable.Map[String, Any]() + | val args = scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray] + | + | if (!out.isEmpty) map("out") = out.get + | + | ${backendArgsMapping.mkString("\n")} + | + | val target = ${firstArg.safeArgName} match { + | case _ if typeOf[T] =:= typeOf[org.apache.mxnet.NDArray] => "sample_${function.name}" + | case _ => "random_${function.name}" + | } + | + | ${unhackNormalFunc(function)} + | + | org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke( + | target, args.toSeq, map.toMap) + |} """.stripMargin - c.parse(finalStr).asInstanceOf[DefDef] + c.parse(impl).asInstanceOf[DefDef] } + + } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index ab864e1ef195..6878b26630f7 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -17,8 +17,8 @@ package org.apache.mxnet + import scala.annotation.StaticAnnotation -import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox @@ -30,6 +30,14 @@ private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation private[mxnet] def macroTransform(annottees: Any*) = macro TypedSymbolAPIMacro.typeSafeAPIDefs } +private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = + macro TypedSymbolRandomAPIMacro.typeSafeAPIDefs +} + +/** + * For non-typed Symbol API + */ private[mxnet] object SymbolMacro extends GeneratorBase { def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { @@ -63,6 +71,9 @@ private[mxnet] object SymbolMacro extends GeneratorBase { } } +/** + * Symbol.api code generation + */ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { @@ -71,9 +82,9 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b)) } - val functions = typeSafeFunctionsToGenerate(isSymbol = true, isContrib) + val functionDefs = typeSafeFunctionsToGenerate(isSymbol = true, isContrib) + .map(f => buildTypedFunction(c)(f)) - val functionDefs = functions.map(f => buildTypedFunction(c)(f)) structGeneration(c)(functionDefs, annottees: _*) } @@ -81,46 +92,115 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { (function: Func): c.universe.DefDef = { import c.universe._ - val returnType = "org.apache.mxnet.Symbol" - val symbolType = "org.apache.mxnet.Symbol" - - // Construct argument field - val argDef = ListBuffer[String]() - argDef ++= typedFunctionCommonArgDef(function) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" - - // Construct Implementation field - val impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - impl += s"var args = scala.collection.Seq[$symbolType]()" - - // Symbol arg implementation - impl ++= function.listOfArgs.map { arg => - if (arg.argType.equals(s"Array[$symbolType]")) { - s"if (!${arg.safeArgName}.isEmpty) args = ${arg.safeArgName}.toSeq" - } else { - // all go in kwargs - if (arg.isOptional) { - s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = ${arg.safeArgName}.get""" + val apiReturnType = "org.apache.mxnet.Symbol" + + // Construct API arguments declaration + val argDecl = super.buildArgDecl(function) :+ + "name : String = null" :+ + "attr : Map[String, String] = null" + + // Map API input args to backend args + val backendArgsMapping = + function.listOfArgs.map { arg => + if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) { + s"args = ${arg.safeArgName}.toSeq" } else { - s"""map("${arg.argName}") = ${arg.safeArgName}""" + // all go in kwargs + if (arg.isOptional) { + s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = ${arg.safeArgName}.get""" + } else { + s"""map("${arg.argName}") = ${arg.safeArgName}""" + } } } - } - impl += - s"""org.apache.mxnet.Symbol.createSymbolGeneral( - | "${function.name}", name, attr, args, map.toMap) + val impl = + s""" + |def ${function.name} + | (${argDecl.mkString(",")}): $apiReturnType = { + | + | val map = scala.collection.mutable.Map[String, Any]() + | var args = scala.collection.Seq[org.apache.mxnet.Symbol]() + | + | ${backendArgsMapping.mkString("\n")} + | + | org.apache.mxnet.Symbol.createSymbolGeneral( + | "${function.name}", name, attr, args, map.toMap) + |} """.stripMargin - // Combine and build the function string - val finalStr = - s"""def ${function.name} - | (${argDef.mkString(",")}) : $returnType - | = {${impl.mkString("\n")}} + c.parse(impl).asInstanceOf[DefDef] + } +} + + +/** + * Symbol.random code generation + */ +private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase + with RandomHelpers { + + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { + val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = true) + .map(f => buildTypedFunction(c)(f)) + + structGeneration(c)(functionDefs, annottees: _*) + } + + protected def buildTypedFunction(c: blackbox.Context) + (function: Func): c.universe.DefDef = { + import c.universe._ + + val apiReturnType = "org.apache.mxnet.Symbol" + + // Construct API arguments declaration + val argDecl = super.buildArgDecl(function) :+ + "name : String = null" :+ + "attr : Map[String, String] = null" + + // Map API input args to backend args + val backendArgsMapping = + function.listOfArgs.map { arg => + if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) { + s"args = ${arg.safeArgName}.toSeq" + } else { + // all go in kwargs + if (arg.isOptional) { + s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = ${arg.safeArgName}.get""" + } else { + s"""map("${arg.argName}") = ${arg.safeArgName}""" + } + } + } + + // since the API is mixing calls using Symbol or Float through template (see unifyRandom), + // to determine the target call, we pick the first arg that is using the template type + val firstArg = function.listOfArgs.filter(arg => arg.argType == "T").head + + val impl = + s""" + |def ${function.name}${randomGenericTypeSpec(true)} + | (${argDecl.mkString(",")}): $apiReturnType = { + | + | import scala.reflect.runtime.universe.typeOf + | val map = scala.collection.mutable.Map[String, Any]() + | var args = scala.collection.Seq[org.apache.mxnet.Symbol]() + | + | ${backendArgsMapping.mkString("\n")} + | + | val target = ${firstArg.safeArgName} match { + | case _ if typeOf[T] =:= typeOf[org.apache.mxnet.Symbol] => "sample_${function.name}" + | case _ => "random_${function.name}" + | } + | + | ${unhackNormalFunc(function)} + | + | org.apache.mxnet.Symbol.createSymbolGeneral( + | target, name, attr, args, map.toMap) + |} """.stripMargin - c.parse(finalStr).asInstanceOf[DefDef] + c.parse(impl).asInstanceOf[DefDef] } } + From c6cb9ff72a20ddebb14b063426be66213861472d Mon Sep 17 00:00:00 2001 From: Mathieu DESPRIEE Date: Mon, 5 Nov 2018 22:50:53 +0100 Subject: [PATCH 2/9] revert useless changes --- .../scala/org/apache/mxnet/APIDocGenerator.scala | 12 ++++++------ .../main/scala/org/apache/mxnet/GeneratorBase.scala | 2 +- .../main/scala/org/apache/mxnet/NDArrayMacro.scala | 12 ++++++------ .../main/scala/org/apache/mxnet/SymbolMacro.scala | 12 ++++++------ 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index e022e533a214..151d649e6b87 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -132,21 +132,21 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { } def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = { - val argDecl = ListBuffer[String]() + val argDef = ListBuffer[String]() - argDecl ++= buildArgDecl(func) + argDef ++= typedFunctionCommonArgDef(func) if (isSymbol) { - argDecl += "name : String = null" - argDecl += "attr : Map[String, String] = null" + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" } else { - argDecl += "out : Option[NDArray] = None" + argDef += "out : Option[NDArray] = None" } val returnType = func.returnType s"""@Experimental - |def ${func.name}$typeParameter (${argDecl.mkString(", ")}): $returnType""".stripMargin + |def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin } def writeFile(FILE_PATH: String, className: String, packageDef: String, diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala index 29bc9503863e..8eb5e0638eec 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -141,7 +141,7 @@ private[mxnet] abstract class GeneratorBase { } // build function argument definition, with optionality, and safe names - protected def buildArgDecl(func: Func): List[String] = { + protected def typedFunctionCommonArgDef(func: Func): List[String] = { func.listOfArgs.map(arg => if (arg.isOptional) { // let's avoid a stupid Option[Array[...]] diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 6e39b87af1d1..6fa5143a94c7 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -98,10 +98,10 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { (function: Func): c.universe.DefDef = { import c.universe._ - val apiReturnType = "org.apache.mxnet.NDArrayFuncReturn" + val returnType = "org.apache.mxnet.NDArrayFuncReturn" // Construct API arguments declaration - val argDecl = super.buildArgDecl(function) :+ "out : Option[NDArray] = None" + val argDecl = super.typedFunctionCommonArgDef(function) :+ "out : Option[NDArray] = None" // Map API input args to backend args val backendArgsMapping = @@ -123,7 +123,7 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { val impl = s""" |def ${function.name} - | (${argDecl.mkString(",")}): $apiReturnType = { + | (${argDecl.mkString(",")}): $returnType = { | | val map = scala.collection.mutable.Map[String, Any]() | val args = scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray] @@ -161,10 +161,10 @@ private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase (function: Func): c.universe.DefDef = { import c.universe._ - val apiReturnType = "org.apache.mxnet.NDArrayFuncReturn" + val returnType = "org.apache.mxnet.NDArrayFuncReturn" // Construct API arguments declaration - val argDecl = super.buildArgDecl(function) :+ "out : Option[NDArray] = None" + val argDecl = super.typedFunctionCommonArgDef(function) :+ "out : Option[NDArray] = None" // Map API input args to backend args val backendArgsMapping = @@ -207,7 +207,7 @@ private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase val impl = s""" |def ${function.name}${randomGenericTypeSpec(false)} - | (${argDecl.mkString(",")}): $apiReturnType = { + | (${argDecl.mkString(",")}): $returnType = { | | import scala.reflect.runtime.universe.typeOf | val map = scala.collection.mutable.Map[String, Any]() diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 6878b26630f7..56457905da61 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -92,10 +92,10 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { (function: Func): c.universe.DefDef = { import c.universe._ - val apiReturnType = "org.apache.mxnet.Symbol" + val returnType = "org.apache.mxnet.Symbol" // Construct API arguments declaration - val argDecl = super.buildArgDecl(function) :+ + val argDecl = super.typedFunctionCommonArgDef(function) :+ "name : String = null" :+ "attr : Map[String, String] = null" @@ -117,7 +117,7 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { val impl = s""" |def ${function.name} - | (${argDecl.mkString(",")}): $apiReturnType = { + | (${argDecl.mkString(",")}): $returnType = { | | val map = scala.collection.mutable.Map[String, Any]() | var args = scala.collection.Seq[org.apache.mxnet.Symbol]() @@ -151,10 +151,10 @@ private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase (function: Func): c.universe.DefDef = { import c.universe._ - val apiReturnType = "org.apache.mxnet.Symbol" + val returnType = "org.apache.mxnet.Symbol" // Construct API arguments declaration - val argDecl = super.buildArgDecl(function) :+ + val argDecl = super.typedFunctionCommonArgDef(function) :+ "name : String = null" :+ "attr : Map[String, String] = null" @@ -180,7 +180,7 @@ private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase val impl = s""" |def ${function.name}${randomGenericTypeSpec(true)} - | (${argDecl.mkString(",")}): $apiReturnType = { + | (${argDecl.mkString(",")}): $returnType = { | | import scala.reflect.runtime.universe.typeOf | val map = scala.collection.mutable.Map[String, Any]() From 5fdcd6265891b336c938787e3c117f32202f711f Mon Sep 17 00:00:00 2001 From: Mathieu DESPRIEE Date: Mon, 5 Nov 2018 23:17:48 +0100 Subject: [PATCH 3/9] shorter types in APIDoc gen code --- .../org/apache/mxnet/APIDocGenerator.scala | 20 +++++++++++-------- .../org/apache/mxnet/GeneratorBase.scala | 7 ++++--- .../scala/org/apache/mxnet/NDArrayMacro.scala | 2 +- .../scala/org/apache/mxnet/SymbolMacro.scala | 2 +- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 151d649e6b87..22fea05aa965 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -58,8 +58,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { writeFile( FILE_PATH, - if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase", "package org.apache.mxnet", + if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase", + "import org.apache.mxnet.annotation.Experimental", generated) } @@ -67,15 +68,17 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { val generated = typeSafeRandomFunctionsToGenerate(isSymbol) .map { func => val scalaDoc = generateAPIDocFromBackend(func) - val typeParameter = randomGenericTypeSpec(isSymbol) + val typeParameter = randomGenericTypeSpec(isSymbol, false) val decl = generateAPISignature(func, isSymbol, typeParameter) s"$scalaDoc\n$decl" } writeFile( FILE_PATH, - if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase", "package org.apache.mxnet", + if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase", + """import org.apache.mxnet.annotation.Experimental + |import scala.reflect.runtime.universe.TypeTag""".stripMargin, generated) } @@ -102,14 +105,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { writeFile( FILE_PATH, - if (isSymbol) "SymbolBase" else "NDArrayBase", "package org.apache.mxnet", + if (isSymbol) "SymbolBase" else "NDArrayBase", + "import org.apache.mxnet.annotation.Experimental", absFuncs) } def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = { val desc = func.desc.split("\n") - .mkString(" *
\n", "\n  * ", "  * 
\n") + .mkString(" *
", "\n  * ", "\n  * 
") val params = func.listOfArgs.map { absClassArg => s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}" @@ -149,8 +153,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { |def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin } - def writeFile(FILE_PATH: String, className: String, packageDef: String, - absFuncs: Seq[String]): String = { + def writeFile(FILE_PATH: String, packageDef: String, className: String, + imports: String, absFuncs: Seq[String]): String = { val finalStr = s"""/* @@ -172,7 +176,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { | |$packageDef | - |import org.apache.mxnet.annotation.Experimental + |$imports | |// scalastyle:off |abstract class $className { diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala index 8eb5e0638eec..bcef452655b3 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -163,9 +163,10 @@ private[mxnet] trait RandomHelpers { self: GeneratorBase => // a generic type spec used in Symbol.random and NDArray.random modules - protected def randomGenericTypeSpec(isSymbol: Boolean): String = { - if (isSymbol) "[T: SymbolOrValue : scala.reflect.runtime.universe.TypeTag]" - else "[T: NDArrayOrValue : scala.reflect.runtime.universe.TypeTag]" + protected def randomGenericTypeSpec(isSymbol: Boolean, fullPackageSpec: Boolean): String = { + val typeTag = if (fullPackageSpec) "scala.reflect.runtime.universe.TypeTag" else "TypeTag" + if (isSymbol) s"[T: SymbolOrValue : $typeTag]" + else s"[T: NDArrayOrValue : $typeTag]" } // filter the operators to generate in the type-safe Symbol.random and NDArray.random diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 6fa5143a94c7..ebf45b07392c 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -206,7 +206,7 @@ private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase val impl = s""" - |def ${function.name}${randomGenericTypeSpec(false)} + |def ${function.name}${randomGenericTypeSpec(false, true)} | (${argDecl.mkString(",")}): $returnType = { | | import scala.reflect.runtime.universe.typeOf diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 56457905da61..05c3e0f561e5 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -179,7 +179,7 @@ private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase val impl = s""" - |def ${function.name}${randomGenericTypeSpec(true)} + |def ${function.name}${randomGenericTypeSpec(true, true)} | (${argDecl.mkString(",")}): $returnType = { | | import scala.reflect.runtime.universe.typeOf From 526f78cb6ed9cbf7df9c60bf6b3778070db09572 Mon Sep 17 00:00:00 2001 From: Mathieu DESPRIEE Date: Tue, 27 Nov 2018 16:44:28 +0100 Subject: [PATCH 4/9] fix after merge from master --- .../src/main/scala/org/apache/mxnet/APIDocGenerator.scala | 7 ++++++- .../src/main/scala/org/apache/mxnet/GeneratorBase.scala | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 94375cc822ea..3f3c2a1f5512 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -132,7 +132,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { }).toSeq val packageName = "NDArrayBase" val packageDef = "package org.apache.mxnet.javaapi" - writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs) + writeFile( + filePath + "javaapi/", + packageDef, + packageName, + "import org.apache.mxnet.annotation.Experimental", + absFuncs) } def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = { diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala index d04ff1d378c1..8c8331f6670f 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -191,7 +191,7 @@ private[mxnet] trait RandomHelpers { // unify call targets (random_xyz and sample_xyz) and unify their argument types private def unifyRandom(func: Func, isSymbol: Boolean): Func = { var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol", - "org.apache.mxnet.Base.MXFloat", "Int") + "java.lang.Float", "java.lang.Integer") func.copy( name = func.name.replaceAll("(random|sample)_", ""), From d9695c5410feda4f2822ddc319aeec5dbbf216cf Mon Sep 17 00:00:00 2001 From: Mathieu DESPRIEE Date: Wed, 5 Dec 2018 10:02:41 +0100 Subject: [PATCH 5/9] Trigger CI From d4cb629ee1e442a56fd13c23288c42b380658312 Mon Sep 17 00:00:00 2001 From: Mathieu DESPRIEE Date: Tue, 11 Dec 2018 11:36:39 +0100 Subject: [PATCH 6/9] temp code / diag on CI --- .../src/main/scala/org/apache/mxnet/SymbolMacro.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 05c3e0f561e5..d489c0913931 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -173,6 +173,14 @@ private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase } } + + // TEMPORARY CODE (diag problem on CI) + if(function.listOfArgs.filter(arg => arg.argType == "T").isEmpty) { + // scalastyle:off println + println(s"Func: ${function.name}, args=[${function.listOfArgs.mkString(",")}]") + } + + // since the API is mixing calls using Symbol or Float through template (see unifyRandom), // to determine the target call, we pick the first arg that is using the template type val firstArg = function.listOfArgs.filter(arg => arg.argType == "T").head From c6af7931a358d025a42ba7b8ffea50a5f372a005 Mon Sep 17 00:00:00 2001 From: Mathieu DESPRIEE Date: Wed, 12 Dec 2018 08:41:13 +0100 Subject: [PATCH 7/9] cleanup type-class code --- .../main/scala/org/apache/mxnet/Base.scala | 24 ++++++------- .../org/apache/mxnet/APIDocGenerator.scala | 2 +- .../org/apache/mxnet/GeneratorBase.scala | 6 ++-- .../scala/org/apache/mxnet/NDArrayMacro.scala | 36 +++++++++---------- .../scala/org/apache/mxnet/SymbolMacro.scala | 23 ++++-------- 5 files changed, 39 insertions(+), 52 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala index 9a28e5a5927b..8924a3970db3 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala @@ -154,18 +154,18 @@ private[mxnet] object Base { class MXNetError(val err: String) extends Exception(err) -// used in the API Symbol.random for functions accepting multiple input types -class SymbolOrValue[T] -object SymbolOrValue { - implicit object FSymbolWitness extends SymbolOrValue[Float] - implicit object ISymbolWitness extends SymbolOrValue[Int] - implicit object SymbolWitness extends SymbolOrValue[Symbol] +// Some type-classes to ease the work in Symbol.random and NDArray.random modules + +class SymbolOrScalar[T](val isScalar: Boolean) +object SymbolOrScalar { + implicit object FloatWitness extends SymbolOrScalar[Float](true) + implicit object IntWitness extends SymbolOrScalar[Int](true) + implicit object SymbolWitness extends SymbolOrScalar[Symbol](false) } -// used in the API NDArray.random for functions accepting multiple input types -class NDArrayOrValue[T] -object NDArrayOrValue { - implicit object FArrayWitness extends NDArrayOrValue[Float] - implicit object IArrayWitness extends NDArrayOrValue[Int] - implicit object ArrayWitness extends NDArrayOrValue[NDArray] +class NDArrayOrScalar[T](val isScalar: Boolean) +object NDArrayOrScalar { + implicit object FloatWitness extends NDArrayOrScalar[Float](true) + implicit object IntWitness extends NDArrayOrScalar[Int](true) + implicit object NDArrayWitness extends NDArrayOrScalar[NDArray](false) } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 3f3c2a1f5512..97cd18a5b337 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -79,7 +79,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { "package org.apache.mxnet", if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase", """import org.apache.mxnet.annotation.Experimental - |import scala.reflect.runtime.universe.TypeTag""".stripMargin, + |import scala.reflect.ClassTag""".stripMargin, generated) } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala index 8c8331f6670f..1c2c4fd704b3 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -169,9 +169,9 @@ private[mxnet] trait RandomHelpers { // a generic type spec used in Symbol.random and NDArray.random modules protected def randomGenericTypeSpec(isSymbol: Boolean, fullPackageSpec: Boolean): String = { - val typeTag = if (fullPackageSpec) "scala.reflect.runtime.universe.TypeTag" else "TypeTag" - if (isSymbol) s"[T: SymbolOrValue : $typeTag]" - else s"[T: NDArrayOrValue : $typeTag]" + val classTag = if (fullPackageSpec) "scala.reflect.ClassTag" else "ClassTag" + if (isSymbol) s"[T: SymbolOrScalar : $classTag]" + else s"[T: NDArrayOrScalar : $classTag]" } // filter the operators to generate in the type-safe Symbol.random and NDArray.random diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index ebf45b07392c..842251569c4e 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -175,24 +175,25 @@ private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase } else { if (arg.argType.equals("T")) { if (arg.isOptional) { - s"""${arg.safeArgName} match { - | case None => - | case Some(a) if typeOf[T] =:= typeOf[org.apache.mxnet.NDArray] => - | args += a.asInstanceOf[org.apache.mxnet.NDArray] - | case Some(b) => map("${arg.argName}") = b + s"""if(${arg.safeArgName}.isDefined) { + | if(isScalar) { + | map("${arg.argName}") = ${arg.safeArgName}.get + | } else { + | args += ${arg.safeArgName}.get.asInstanceOf[org.apache.mxnet.NDArray] + | } |} """.stripMargin } else { - s"""${arg.safeArgName} match { - | case a if typeOf[T] =:= typeOf[org.apache.mxnet.NDArray] => - | args += a.asInstanceOf[org.apache.mxnet.NDArray] - | case b => map("${arg.argName}") = b + s"""if(isScalar) { + | map("${arg.argName}") = ${arg.safeArgName} + |} else { + | args += ${arg.safeArgName}.asInstanceOf[org.apache.mxnet.NDArray] |} """.stripMargin } } else { if (arg.isOptional) { - s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = ${arg.safeArgName}.get""" + s"""if (${arg.safeArgName}.isDefined) map("${arg.argName}") = ${arg.safeArgName}.get""" } else { s"""map("${arg.argName}") = ${arg.safeArgName}""" } @@ -200,26 +201,23 @@ private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase } } - // since the API is mixing calls using NDArrays or Float through template (see unifyRandom), - // to determine the target call, we pick the first arg that is using the template type - val firstArg = function.listOfArgs.filter(arg => arg.argType == "T").head - val impl = s""" |def ${function.name}${randomGenericTypeSpec(false, true)} | (${argDecl.mkString(",")}): $returnType = { | - | import scala.reflect.runtime.universe.typeOf | val map = scala.collection.mutable.Map[String, Any]() | val args = scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray] + | val isScalar = implicitly[NDArrayOrScalar[T]].isScalar | - | if (!out.isEmpty) map("out") = out.get + | if(out.isDefined) map("out") = out.get | | ${backendArgsMapping.mkString("\n")} | - | val target = ${firstArg.safeArgName} match { - | case _ if typeOf[T] =:= typeOf[org.apache.mxnet.NDArray] => "sample_${function.name}" - | case _ => "random_${function.name}" + | val target = if(isScalar) { + | "random_${function.name}" + | } else { + | "sample_${function.name}" | } | | ${unhackNormalFunc(function)} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index d489c0913931..a50b1cd1f2b5 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -166,39 +166,28 @@ private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase } else { // all go in kwargs if (arg.isOptional) { - s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = ${arg.safeArgName}.get""" + s"""if (${arg.safeArgName}.isDefined) map("${arg.argName}") = ${arg.safeArgName}.get""" } else { s"""map("${arg.argName}") = ${arg.safeArgName}""" } } } - - // TEMPORARY CODE (diag problem on CI) - if(function.listOfArgs.filter(arg => arg.argType == "T").isEmpty) { - // scalastyle:off println - println(s"Func: ${function.name}, args=[${function.listOfArgs.mkString(",")}]") - } - - - // since the API is mixing calls using Symbol or Float through template (see unifyRandom), - // to determine the target call, we pick the first arg that is using the template type - val firstArg = function.listOfArgs.filter(arg => arg.argType == "T").head - val impl = s""" |def ${function.name}${randomGenericTypeSpec(true, true)} | (${argDecl.mkString(",")}): $returnType = { | - | import scala.reflect.runtime.universe.typeOf | val map = scala.collection.mutable.Map[String, Any]() | var args = scala.collection.Seq[org.apache.mxnet.Symbol]() + | val isScalar = implicitly[SymbolOrScalar[T]].isScalar | | ${backendArgsMapping.mkString("\n")} | - | val target = ${firstArg.safeArgName} match { - | case _ if typeOf[T] =:= typeOf[org.apache.mxnet.Symbol] => "sample_${function.name}" - | case _ => "random_${function.name}" + | val target = if(isScalar) { + | "random_${function.name}" + | } else { + | "sample_${function.name}" | } | | ${unhackNormalFunc(function)} From 1f7d0feded55a8c7fdb3120ceb67b913fe82c811 Mon Sep 17 00:00:00 2001 From: Mathieu DESPRIEE Date: Wed, 12 Dec 2018 09:07:20 +0100 Subject: [PATCH 8/9] cleanup type-class code --- scala-package/core/src/main/scala/org/apache/mxnet/Base.scala | 2 ++ .../macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala | 2 +- .../macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala index 8924a3970db3..bb9518d51f1e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala @@ -158,6 +158,7 @@ class MXNetError(val err: String) extends Exception(err) class SymbolOrScalar[T](val isScalar: Boolean) object SymbolOrScalar { + def apply[T](implicit ev: SymbolOrScalar[T]): SymbolOrScalar[T] = ev implicit object FloatWitness extends SymbolOrScalar[Float](true) implicit object IntWitness extends SymbolOrScalar[Int](true) implicit object SymbolWitness extends SymbolOrScalar[Symbol](false) @@ -165,6 +166,7 @@ object SymbolOrScalar { class NDArrayOrScalar[T](val isScalar: Boolean) object NDArrayOrScalar { + def apply[T](implicit ev: NDArrayOrScalar[T]): NDArrayOrScalar[T] = ev implicit object FloatWitness extends NDArrayOrScalar[Float](true) implicit object IntWitness extends NDArrayOrScalar[Int](true) implicit object NDArrayWitness extends NDArrayOrScalar[NDArray](false) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 842251569c4e..230e82539f7b 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -208,7 +208,7 @@ private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase | | val map = scala.collection.mutable.Map[String, Any]() | val args = scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray] - | val isScalar = implicitly[NDArrayOrScalar[T]].isScalar + | val isScalar = NDArrayOrScalar[T].isScalar | | if(out.isDefined) map("out") = out.get | diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index a50b1cd1f2b5..7ec80b9c066c 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -180,7 +180,7 @@ private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase | | val map = scala.collection.mutable.Map[String, Any]() | var args = scala.collection.Seq[org.apache.mxnet.Symbol]() - | val isScalar = implicitly[SymbolOrScalar[T]].isScalar + | val isScalar = SymbolOrScalar[T].isScalar | | ${backendArgsMapping.mkString("\n")} | From a87c1135eb15db3c5d5465c63c2ec8bb890b7707 Mon Sep 17 00:00:00 2001 From: Mathieu DESPRIEE Date: Wed, 12 Dec 2018 14:14:30 +0100 Subject: [PATCH 9/9] fix scalastyle --- .../macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 230e82539f7b..c18694b59bf6 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -193,7 +193,7 @@ private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase } } else { if (arg.isOptional) { - s"""if (${arg.safeArgName}.isDefined) map("${arg.argName}") = ${arg.safeArgName}.get""" + s"""if (${arg.safeArgName}.isDefined) map("${arg.argName}")=${arg.safeArgName}.get""" } else { s"""map("${arg.argName}") = ${arg.safeArgName}""" }