diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala index 70e610aab55a..eeeb621d1480 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala @@ -19,6 +19,7 @@ package org.apache.mxnet import org.scalatest.{BeforeAndAfterAll, FunSuite} + class SymbolSuite extends FunSuite with BeforeAndAfterAll { test("symbol compose") { diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 309f164c84a4..6c4eb2f9a03f 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -17,8 +17,6 @@ package org.apache.mxnet -import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.CToScalaUtils import java.io._ import java.security.MessageDigest @@ -29,13 +27,11 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala * The code will be executed during Macros stage and file live in Core stage */ -private[mxnet] object APIDocGenerator{ - case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean) - case class absClassFunction(name : String, desc : String, - listOfArgs: List[absClassArg], returnType : String) +private[mxnet] object APIDocGenerator extends GeneratorBase { + type absClassArg = Arg + type absClassFunction = Func - - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val FILE_PATH = args(0) val hashCollector = ListBuffer[String]() hashCollector += absClassGen(FILE_PATH, true) @@ -47,68 +43,70 @@ private[mxnet] object APIDocGenerator{ val finalHash = hashCollector.mkString("\n") } - def MD5Generator(input : String) : String = { + def MD5Generator(input: String): String = { val md = MessageDigest.getInstance("MD5") md.update(input.getBytes("UTF-8")) val digest = md.digest() org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest) } - def absRndClassGen(FILE_PATH : String, isSymbol : Boolean) : String = { - typeSafeClassGen( - getSymbolNDArrayMethods(isSymbol) - .filter(f => f.name.startsWith("_random") || f.name.startsWith("_sample")) - .map(f => f.copy(name = f.name.stripPrefix("_"))), + def absRndClassGen(FILE_PATH: String, isSymbol: Boolean): String = { + val funcs = getSymbolNDArrayMethods(isSymbol) + .filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_")) + .map(f => f.copy(name = f.name.stripPrefix("_"))) + val body = funcs.map(func => { + val scalaDoc = generateAPIDocFromBackend(func) + val decl = generateRandomAPISignature(func, isSymbol) + s"$scalaDoc\n$decl" + }) + writeFile( FILE_PATH, if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase", - isSymbol - ) + body) } - def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = { + def absClassGen(FILE_PATH: String, isSymbol: Boolean): String = { val notGenerated = Set("Custom") - typeSafeClassGen( - getSymbolNDArrayMethods(isSymbol) - .filterNot(_.name.startsWith("_")) - .filterNot(ele => notGenerated.contains(ele.name)), + val funcs = getSymbolNDArrayMethods(isSymbol) + .filterNot(_.name.startsWith("_")) + .filterNot(ele => notGenerated.contains(ele.name)) + val body = funcs.map(func => { + val scalaDoc = generateAPIDocFromBackend(func) + val decl = generateAPISignature(func, isSymbol) + s"$scalaDoc\n$decl" + }) + writeFile( FILE_PATH, if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase", - isSymbol - ) + body) } - def typeSafeClassGen(absClassFunctions: Seq[absClassFunction], FILE_PATH: String, - packageName: String, isSymbol: Boolean): String = { - val absFuncs = absClassFunctions - .map(absClassFunction => { - val scalaDoc = generateAPIDocFromBackend(absClassFunction) - val defBody = generateAPISignature(absClassFunction, isSymbol) - s"$scalaDoc\n$defBody" - }) - writeFile(FILE_PATH, packageName, absFuncs) - } - - def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = { - // scalastyle:off + def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { val absClassFunctions = getSymbolNDArrayMethods(isSymbol) val absFuncs = absClassFunctions .filterNot(_.name.startsWith("_")) .map(absClassFunction => { - val scalaDoc = generateAPIDocFromBackend(absClassFunction, false) - if (isSymbol) { - val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol" - s"$scalaDoc\n$defBody" - } else { - val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn" - val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn" - s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody" - } - }) + val scalaDoc = generateAPIDocFromBackend(absClassFunction, false) + if (isSymbol) { + val defBody = + s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)" + + s"(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): " + + s"org.apache.mxnet.Symbol" + s"$scalaDoc\n$defBody" + } else { + val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)" + + s"(args: Any*): " + + s"org.apache.mxnet.NDArrayFuncReturn" + val defBody = s"def ${absClassFunction.name}(args: Any*): " + + s"org.apache.mxnet.NDArrayFuncReturn" + s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody" + } + }) val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase" writeFile(FILE_PATH, packageName, absFuncs) } - def writeFile(FILE_PATH: String, packageName: String, absFuncs: Seq[String]): String = { + def writeFile(FILE_PATH: String, packageName: String, body: Seq[String]): String = { val apacheLicence = """/* |* Licensed to the Apache Software Foundation (ASF) under one or more @@ -137,7 +135,7 @@ private[mxnet] object APIDocGenerator{ |$packageDef |$imports |$absClassDef { - |${absFuncs.mkString("\n")} + |${body.mkString("\n")} |}""".stripMargin val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala")) pw.write(finalStr) @@ -146,20 +144,15 @@ private[mxnet] object APIDocGenerator{ } // Generate ScalaDoc type - def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = { + def generateAPIDocFromBackend(func: absClassFunction, withParam: Boolean = true): String = { val desc = ArrayBuffer[String]() desc += " *
" - func.desc.split("\n").foreach({ currStr => + func.desc.split("\n").foreach({ currStr => desc += s" * $currStr" }) desc += " *" val params = func.listOfArgs.map({ absClassArg => - val currArgName = absClassArg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case _ => absClassArg.argName - } - s" * @param $currArgName\t\t${absClassArg.argDesc}" + s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}" }) val returnType = s" * @return ${func.returnType}" if (withParam) { @@ -169,64 +162,31 @@ private[mxnet] object APIDocGenerator{ } } - def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = { - var argDef = ListBuffer[String]() - func.listOfArgs.foreach(absClassArg => { - val currArgName = absClassArg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case _ => absClassArg.argName - } - if (absClassArg.isOptional) { - argDef += s"$currArgName : Option[${absClassArg.argType}] = None" - } - else { - argDef += s"$currArgName : ${absClassArg.argType}" - } - }) - var returnType = func.returnType + def generateRandomAPISignature(func: absClassFunction, isSymbol: Boolean): String = { + generateAPISignature(func, isSymbol) + } + + def generateAPISignature(func: absClassFunction, isSymbol: Boolean): String = { + val argDef = ListBuffer[String]() + + argDef ++= buildArgDefs(func) + if (isSymbol) { argDef += "name : String = null" argDef += "attr : Map[String, String] = null" } else { argDef += "out : Option[NDArray] = None" - returnType = "org.apache.mxnet.NDArrayFuncReturn" } + + val returnType = func.returnType + val experimentalTag = "@Experimental" s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType" } // List and add all the atomic symbol functions to current module. - private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = { - val opNames = ListBuffer.empty[String] - val returnType = if (isSymbol) "Symbol" else "NDArray" - _LIB.mxListAllOpNames(opNames) - // TODO: Add '_linalg_', '_sparse_', '_image_' support - // TODO: Add Filter to the same location in case of refactor - opNames.map(opName => { - val opHandle = new RefLong - _LIB.nnGetOpHandle(opName, opHandle) - makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType) - }).toList + private def getSymbolNDArrayMethods(isSymbol: Boolean): List[absClassFunction] = { + buildFunctionList(isSymbol) } - // Create an atomic symbol function by handle and function name. - private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String) - : absClassFunction = { - val name = new RefString - val desc = new RefString - val keyVarNumArgs = new RefString - val numArgs = new RefInt - val argNames = ListBuffer.empty[String] - val argTypes = ListBuffer.empty[String] - val argDescs = ListBuffer.empty[String] - - _LIB.mxSymbolGetAtomicSymbolInfo( - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) - val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => - val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType) - absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2) - } - absClassFunction(aliasName, desc.value, argList.toList, returnType) - } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala new file mode 100644 index 000000000000..e288a8ba214b --- /dev/null +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -0,0 +1,116 @@ +package org.apache.mxnet + +import org.apache.mxnet.init.Base.{RefInt, RefLong, RefString, _LIB} +import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} + +import scala.collection.mutable.ListBuffer +import scala.reflect.macros.blackbox + +abstract class GeneratorBase { + type Handle = Long + + case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) { + def safeArgName: String = argName match { + case "var" => "vari" + case "type" => "typeOf" + case _ => argName + } + } + + case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String) + + protected def buildFunctionList(isSymbol: Boolean): List[Func] = { + val opNames = ListBuffer.empty[String] + _LIB.mxListAllOpNames(opNames) + opNames.map(opName => { + val opHandle = new RefLong + _LIB.nnGetOpHandle(opName, opHandle) + makeAtomicFunction(opHandle.value, opName, isSymbol) + }).toList + } + + protected def makeAtomicFunction(handle: Handle, aliasName: String, isSymbol: Boolean): Func = { + val name = new RefString + val desc = new RefString + val keyVarNumArgs = new RefString + val numArgs = new RefInt + val argNames = ListBuffer.empty[String] + val argTypes = ListBuffer.empty[String] + val argDescs = ListBuffer.empty[String] + + _LIB.mxSymbolGetAtomicSymbolInfo( + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) + val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs) + val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) { + s"This function support variable length of positional input (${keyVarNumArgs.value})." + } else { + "" + } + val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" + val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n" + // scalastyle:off println + if (System.getenv("MXNET4J_PRINT_OP_DEF") != null + && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") { + println("Function definition:\n" + docStr) + } + // scalastyle:on println + val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => + val family = if(isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArray" + val typeAndOption = + CToScalaUtils.argumentCleaner(argName, argType, family) + Arg(argName, typeAndOption._1, argDesc, typeAndOption._2) + } + val returnType = if(isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArrayFuncReturn" + Func(aliasName, desc.value, argList.toList, returnType) + } + + /** + * Generate class structure for all function APIs + * + * @param c + * @param funcDef DefDef type of function definitions + * @param annottees + * @return + */ + protected def structGeneration(c: blackbox.Context) + (funcDef: List[c.universe.DefDef], annottees: c.Expr[Any]*) + : c.Expr[Any] = { + import c.universe._ + val inputs = annottees.map(_.tree).toList + // pattern match on the inputs + val modDefs = inputs map { + case ClassDef(mods, name, something, template) => + val q = template match { + case Template(superMaybe, emptyValDef, defs) => + Template(superMaybe, emptyValDef, defs ++ funcDef) + case ex => + throw new IllegalArgumentException(s"Invalid template: $ex") + } + ClassDef(mods, name, something, q) + case ModuleDef(mods, name, template) => + val q = template match { + case Template(superMaybe, emptyValDef, defs) => + Template(superMaybe, emptyValDef, defs ++ funcDef) + case ex => + throw new IllegalArgumentException(s"Invalid template: $ex") + } + ModuleDef(mods, name, q) + case ex => + throw new IllegalArgumentException(s"Invalid macro input: $ex") + } + // wrap the result up in an Expr, and return it + val result = c.Expr(Block(modDefs, Literal(Constant()))) + result + } + + protected def buildArgDefs(func: Func): List[String] = { + func.listOfArgs.map(arg => + if (arg.isOptional) + s"${arg.safeArgName} : Option[${arg.argType}] = None" + else + s"${arg.safeArgName} : ${arg.argType}" + ) + } + + +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 02e86d2fbcab..c33465d96845 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -17,29 +17,25 @@ package org.apache.mxnet -import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} - import scala.annotation.StaticAnnotation -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addDefs } - private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs } - private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeRandomAPIDefs } -private[mxnet] object NDArrayMacro { - case class NDArrayArg(argName: String, argType: String, isOptional : Boolean) - case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg]) + +private[mxnet] object NDArrayMacro extends GeneratorBase { + type NDArrayArg = Arg + type NDArrayFunction = Func // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { @@ -53,8 +49,11 @@ private[mxnet] object NDArrayMacro { } // scalastyle:off havetype - private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule() + private val ndarrayFunctions: List[NDArrayFunction] = buildFunctionList(false) + /** + * Implementation for fixed input API structure + */ private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ @@ -127,28 +126,20 @@ private[mxnet] object NDArrayMacro { (ndarrayfunction: NDArrayFunction): c.universe.DefDef = { import c.universe._ + val returnType = "org.apache.mxnet.NDArrayFuncReturn" + // Construct argument field - var argDef = ListBuffer[String]() + val argDef = ListBuffer[String]() + argDef ++= buildArgDefs(ndarrayfunction) + // Construct Implementation field var impl = ListBuffer[String]() impl += "val map = scala.collection.mutable.Map[String, Any]()" impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]" + ndarrayfunction.listOfArgs.foreach({ ndarrayarg => - // var is a special word used to define variable in Scala, - // need to changed to something else in order to make it work - val currArgName = ndarrayarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => ndarrayarg.argName - } - if (ndarrayarg.isOptional) { - argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${ndarrayarg.argType}" - } // NDArray arg implementation - val returnType = "org.apache.mxnet.NDArray" + val arrayType = "org.apache.mxnet.NDArray" // TODO: Currently we do not add place holder for NDArray // Example: an NDArray operator like the following format @@ -156,18 +147,19 @@ private[mxnet] object NDArrayMacro { // If we place nd.foo(arg1, arg3 = arg3), do we need to add place holder for arg2? // What it should be? val base = - if (ndarrayarg.argType.equals(returnType)) { - s"args += $currArgName" - } else if (ndarrayarg.argType.equals(s"Array[$returnType]")) { - s"args ++= $currArgName" + if (ndarrayarg.argType.equals(arrayType)) { + s"args += ${ndarrayarg.safeArgName}" + } else if (ndarrayarg.argType.equals(s"Array[$arrayType]")) { + s"args ++= ${ndarrayarg.safeArgName}" } else { - "map(\"" + ndarrayarg.argName + "\") = " + currArgName + "map(\"" + ndarrayarg.argName + "\") = " + ndarrayarg.safeArgName } impl.append( - if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get" + if (ndarrayarg.isOptional) s"if (!${ndarrayarg.safeArgName}.isEmpty) $base.get" else base ) }) + // add default out parameter argDef += "out : Option[NDArray] = None" impl += "if (!out.isEmpty) map(\"out\") = out.get" @@ -175,90 +167,10 @@ private[mxnet] object NDArrayMacro { impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)" // scalastyle:on // Combine and build the function string - val returnType = "org.apache.mxnet.NDArrayFuncReturn" var finalStr = s"def ${ndarrayfunction.name}" finalStr += s" (${argDef.mkString(",")}) : $returnType" finalStr += s" = {${impl.mkString("\n")}}" c.parse(finalStr).asInstanceOf[DefDef] } - private def structGeneration(c: blackbox.Context) - (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*) - : c.Expr[Any] = { - import c.universe._ - val inputs = annottees.map(_.tree).toList - // pattern match on the inputs - val modDefs = inputs map { - case ClassDef(mods, name, something, template) => - val q = template match { - case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef) - case ex => - throw new IllegalArgumentException(s"Invalid template: $ex") - } - ClassDef(mods, name, something, q) - case ModuleDef(mods, name, template) => - val q = template match { - case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef) - case ex => - throw new IllegalArgumentException(s"Invalid template: $ex") - } - ModuleDef(mods, name, q) - case ex => - throw new IllegalArgumentException(s"Invalid macro input: $ex") - } - // wrap the result up in an Expr, and return it - val result = c.Expr(Block(modDefs, Literal(Constant()))) - result - } - - - - - // List and add all the atomic symbol functions to current module. - private def initNDArrayModule(): List[NDArrayFunction] = { - val opNames = ListBuffer.empty[String] - _LIB.mxListAllOpNames(opNames) - opNames.map(opName => { - val opHandle = new RefLong - _LIB.nnGetOpHandle(opName, opHandle) - makeNDArrayFunction(opHandle.value, opName) - }).toList - } - - // Create an atomic symbol function by handle and function name. - private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String) - : NDArrayFunction = { - val name = new RefString - val desc = new RefString - val keyVarNumArgs = new RefString - val numArgs = new RefInt - val argNames = ListBuffer.empty[String] - val argTypes = ListBuffer.empty[String] - val argDescs = ListBuffer.empty[String] - - _LIB.mxSymbolGetAtomicSymbolInfo( - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) - val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs) - val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) { - s"This function support variable length of positional input (${keyVarNumArgs.value})." - } else { - "" - } - val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" - val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n" - // scalastyle:off println - if (System.getenv("MXNET4J_PRINT_OP_DEF") != null - && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") { - println("NDArray function definition:\n" + docStr) - } - // scalastyle:on println - val argList = argNames zip argTypes map { case (argName, argType) => - val typeAndOption = - CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.NDArray") - NDArrayArg(argName, typeAndOption._1, typeAndOption._2) - } - NDArrayFunction(aliasName, argList.toList) - } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 88eca6955068..75c1a2f93610 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -21,26 +21,21 @@ import scala.annotation.StaticAnnotation import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox -import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs } - private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typeSafeAPIDefs } - private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typedRandomAPIDefs } - -private[mxnet] object SymbolImplMacros { - case class SymbolArg(argName: String, argType: String, isOptional : Boolean) - case class SymbolFunction(name: String, listOfArgs: List[SymbolArg]) - +private[mxnet] object SymbolImplMacros extends GeneratorBase { + type SymbolArg = Arg + type SymbolFunction = Func + // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { impl(c)(annottees: _*) @@ -53,7 +48,7 @@ private[mxnet] object SymbolImplMacros { } // scalastyle:on havetype - private val symbolFunctions: List[SymbolFunction] = initSymbolModule() + private val symbolFunctions: List[SymbolFunction] = buildFunctionList(true) /** * Implementation for fixed input API structure @@ -72,31 +67,17 @@ private[mxnet] object SymbolImplMacros { } val functionDefs = newSymbolFunctions map { symbolfunction => - val funcName = symbolfunction.name - val tName = TermName(funcName) - q""" + val funcName = symbolfunction.name + val tName = TermName(funcName) + q""" def $tName(name : String = null, attr : Map[String, String] = null) (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null) : org.apache.mxnet.Symbol = { createSymbolGeneral($funcName,name,attr,args,kwargs) } """.asInstanceOf[DefDef] - } - - structGeneration(c)(functionDefs, annottees : _*) - } - - /** - * Implementation for Dynamic typed API Symbol.random.