From 251d2cd58f99546819bc61acc5931e838f631ed3 Mon Sep 17 00:00:00 2001 From: mathieu Date: Mon, 5 Nov 2018 18:55:45 +0100 Subject: [PATCH] [MXNET-918] Introduce Random module / Refact code generation (#13038) * refactor code gen * remove xxxAPIMacroBase (overkill) * CI errors / scala-style * PR review comments --- .../benchmark/ObjectDetectionBenchmark.java | 2 +- .../org/apache/mxnet/APIDocGenerator.scala | 311 +++++++----------- .../org/apache/mxnet/GeneratorBase.scala | 163 +++++++++ .../scala/org/apache/mxnet/NDArrayMacro.scala | 263 +++++---------- .../scala/org/apache/mxnet/SymbolMacro.scala | 250 ++++---------- 5 files changed, 440 insertions(+), 549 deletions(-) create mode 100644 scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java index 485e0afa3e46..257ea3241626 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java @@ -57,7 +57,7 @@ public void runBatchInference() { List nd = new ArrayList<>(); NDArray[] temp = new NDArray[batchSize]; for (int i = 0; i < batchSize; i++) temp[i] = img.copy(); - NDArray batched = NDArray.concat(temp, batchSize).setdim(0).invoke().get(); + NDArray batched = NDArray.concat(temp, batchSize, 0, null)[0]; nd.add(batched); objDet.objectDetectWithNDArray(nd, 3); } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index f2326868e8e7..16592085c7b5 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -17,196 +17,147 @@ package org.apache.mxnet -import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.CToScalaUtils import java.io._ import java.security.MessageDigest -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.ListBuffer /** * This object will generate the Scala documentation of the new Scala API * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala * The code will be executed during Macros stage and file live in Core stage */ -private[mxnet] object APIDocGenerator{ - case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean) - case class absClassFunction(name : String, desc : String, - listOfArgs: List[absClassArg], returnType : String) +private[mxnet] object APIDocGenerator extends GeneratorBase { - - def main(args: Array[String]) : Unit = { + def main(args: Array[String]): Unit = { val FILE_PATH = args(0) val hashCollector = ListBuffer[String]() - hashCollector += absClassGen(FILE_PATH, true) - hashCollector += absClassGen(FILE_PATH, false) + hashCollector += typeSafeClassGen(FILE_PATH, true) + hashCollector += typeSafeClassGen(FILE_PATH, false) hashCollector += nonTypeSafeClassGen(FILE_PATH, true) hashCollector += nonTypeSafeClassGen(FILE_PATH, false) - // Generate Java API documentation - hashCollector += javaClassGen(FILE_PATH + "javaapi/") + hashCollector += javaClassGen(FILE_PATH) val finalHash = hashCollector.mkString("\n") } - def MD5Generator(input : String) : String = { + def MD5Generator(input: String): String = { val md = MessageDigest.getInstance("MD5") md.update(input.getBytes("UTF-8")) val digest = md.digest() org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest) } - def fileGen(filePath : String, packageName : String, packageDef : String, - absFuncs : List[String]) : String = { - val apacheLicense = - """/* - |* Licensed to the Apache Software Foundation (ASF) under one or more - |* contributor license agreements. See the NOTICE file distributed with - |* this work for additional information regarding copyright ownership. - |* The ASF licenses this file to You under the Apache License, Version 2.0 - |* (the "License"); you may not use this file except in compliance with - |* the License. You may obtain a copy of the License at - |* - |* http://www.apache.org/licenses/LICENSE-2.0 - |* - |* Unless required by applicable law or agreed to in writing, software - |* distributed under the License is distributed on an "AS IS" BASIS, - |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - |* See the License for the specific language governing permissions and - |* limitations under the License. - |*/ - |""".stripMargin - val scalaStyle = "// scalastyle:off" - val imports = "import org.apache.mxnet.annotation.Experimental" - val absClassDef = s"abstract class $packageName" + def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { + val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false) + .map { func => + val scalaDoc = generateAPIDocFromBackend(func) + val decl = generateAPISignature(func, isSymbol) + s"$scalaDoc\n$decl" + } - val finalStr = - s"""$apacheLicense - |$scalaStyle - |$packageDef - |$imports - |$absClassDef { - |${absFuncs.mkString("\n")} - |}""".stripMargin - val pw = new PrintWriter(new File(filePath + s"$packageName.scala")) - pw.write(finalStr) - pw.close() - MD5Generator(finalStr) + writeFile( + FILE_PATH, + if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase", + "package org.apache.mxnet", + generated) } - def absClassGen(filePath : String, isSymbol : Boolean) : String = { - val absClassFunctions = getSymbolNDArrayMethods(isSymbol) - // Defines Operators that should not generated - val notGenerated = Set("Custom") - // TODO: Add Filter to the same location in case of refactor - val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")) - .filterNot(ele => notGenerated.contains(ele.name)) - .map(absClassFunction => { - val scalaDoc = generateAPIDocFromBackend(absClassFunction) - val defBody = generateAPISignature(absClassFunction, isSymbol) - s"$scalaDoc\n$defBody" - }) - val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase" - val packageDef = "package org.apache.mxnet" - fileGen(filePath, packageName, packageDef, absFuncs) + def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { + val absFuncs = functionsToGenerate(isSymbol, isContrib = false) + .map { func => + val scalaDoc = generateAPIDocFromBackend(func, false) + if (isSymbol) { + s"""$scalaDoc + |def ${func.name}(name : String = null, attr : Map[String, String] = null) + | (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): + | org.apache.mxnet.Symbol + """.stripMargin + } else { + s"""$scalaDoc + |def ${func.name}(kwargs: Map[String, Any] = null) + | (args: Any*): org.apache.mxnet.NDArrayFuncReturn + | + |$scalaDoc + |def ${func.name}(args: Any*): org.apache.mxnet.NDArrayFuncReturn + """.stripMargin + } + } + + writeFile( + FILE_PATH, + if (isSymbol) "SymbolBase" else "NDArrayBase", + "package org.apache.mxnet", + absFuncs) } def javaClassGen(filePath : String) : String = { val notGenerated = Set("Custom") - val absClassFunctions = getSymbolNDArrayMethods(false, true) - // TODO: Add Filter to the same location in case of refactor - val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")) - .filterNot(ele => notGenerated.contains(ele.name)) - .map(absClassFunction => { + val absClassFunctions = functionsToGenerate(false, false, true) + val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name)) + .groupBy(_.name.toLowerCase).map(ele => { + // Pattern matching for not generating depreciated method + if (ele._2.length == 1) ele._2.head + else { + if (ele._2.head.name.head.isLower) ele._2.head + else ele._2.last + } + }).map(absClassFunction => { generateJavaAPISignature(absClassFunction) - }) + }).toSeq val packageName = "NDArrayBase" val packageDef = "package org.apache.mxnet.javaapi" - fileGen(filePath, packageName, packageDef, absFuncs) + writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs) } - def nonTypeSafeClassGen(filePath : String, isSymbol : Boolean) : String = { - // scalastyle:off - val absClassFunctions = getSymbolNDArrayMethods(isSymbol) - val absFuncs = absClassFunctions.map(absClassFunction => { - val scalaDoc = generateAPIDocFromBackend(absClassFunction, false) - if (isSymbol) { - val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol" - s"$scalaDoc\n$defBody" - } else { - val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn" - val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn" - s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody" - } - }) - val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase" - val packageDef = "package org.apache.mxnet" - fileGen(filePath, packageName, packageDef, absFuncs) - } + def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = { + val desc = func.desc.split("\n") + .mkString(" *
\n", "\n  * ", "  * 
\n") - /** - * Some of the C++ type name is not valid in Scala - * such as var and type. This method is to convert - * them into other names to get it passed - * @param in the input String - * @return converted name string - */ - def safetyNameCheck(in : String) : String = { - in match { - case "var" => "vari" - case "type" => "typeOf" - case _ => in + val params = func.listOfArgs.map { absClassArg => + s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}" } - } - // Generate ScalaDoc type - def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = { - val desc = ArrayBuffer[String]() - desc += " *
"
-      func.desc.split("\n").foreach({ currStr =>
-      desc += s"  * $currStr"
-    })
-    desc += "  * 
" - val params = func.listOfArgs.map({ absClassArg => - val currArgName = safetyNameCheck(absClassArg.argName) - s" * @param $currArgName\t\t${absClassArg.argDesc}" - }) val returnType = s" * @return ${func.returnType}" + if (withParam) { - s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" + s""" /** + |$desc + |${params.mkString("\n")} + |$returnType + | */""".stripMargin } else { - s" /**\n${desc.mkString("\n")}\n$returnType\n */" + s""" /** + |$desc + |$returnType + | */""".stripMargin } } - def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = { - var argDef = ListBuffer[String]() - func.listOfArgs.foreach(absClassArg => { - val currArgName = safetyNameCheck(absClassArg.argName) - if (absClassArg.isOptional) { - argDef += s"$currArgName : Option[${absClassArg.argType}] = None" - } - else { - argDef += s"$currArgName : ${absClassArg.argType}" - } - }) - var returnType = func.returnType + def generateAPISignature(func: Func, isSymbol: Boolean): String = { + val argDef = ListBuffer[String]() + + argDef ++= typedFunctionCommonArgDef(func) + if (isSymbol) { argDef += "name : String = null" argDef += "attr : Map[String, String] = null" } else { argDef += "out : Option[NDArray] = None" - returnType = "org.apache.mxnet.NDArrayFuncReturn" } - val experimentalTag = "@Experimental" - s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType" + + val returnType = func.returnType + + s"""@Experimental + |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin } - def generateJavaAPISignature(func : absClassFunction) : String = { + def generateJavaAPISignature(func : Func) : String = { val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2 var argDef = ListBuffer[String]() var classDef = ListBuffer[String]() var requiredParam = ListBuffer[String]() func.listOfArgs.foreach(absClassArg => { - val currArgName = safetyNameCheck(absClassArg.argName) + val currArgName = absClassArg.safeArgName // scalastyle:off if (absClassArg.isOptional && useParamObject) { classDef += @@ -240,15 +191,15 @@ private[mxnet] object APIDocGenerator{ | def getOut() = this.out | """.stripMargin s"""$scalaDocNoParam - | $experimentalTag - | def ${func.name}(po: ${func.name}Param) : $returnType - | /** - | * This Param Object is specifically used for ${func.name} - | ${requiredParam.mkString("\n")} - | */ - | class ${func.name}Param(${argDef.mkString(",")}) { - | ${classDef.mkString("\n ")} - | }""".stripMargin + | $experimentalTag + | def ${func.name}(po: ${func.name}Param) : $returnType + | /** + | * This Param Object is specifically used for ${func.name} + | ${requiredParam.mkString("\n")} + | */ + | class ${func.name}Param(${argDef.mkString(",")}) { + | ${classDef.mkString("\n ")} + | }""".stripMargin } else { argDef += "out : NDArray" s"""$scalaDoc @@ -258,48 +209,40 @@ private[mxnet] object APIDocGenerator{ } } + def writeFile(FILE_PATH: String, className: String, packageDef: String, + absFuncs: Seq[String]): String = { - // List and add all the atomic symbol functions to current module. - private def getSymbolNDArrayMethods(isSymbol : Boolean, - isJava : Boolean = false): List[absClassFunction] = { - val opNames = ListBuffer.empty[String] - val returnType = if (isSymbol) "Symbol" else "NDArray" - val returnHeader = if (isJava) "org.apache.mxnet.javaapi." else "org.apache.mxnet." - _LIB.mxListAllOpNames(opNames) - // TODO: Add '_linalg_', '_sparse_', '_image_' support - // TODO: Add Filter to the same location in case of refactor - opNames.map(opName => { - val opHandle = new RefLong - _LIB.nnGetOpHandle(opName, opHandle) - makeAtomicSymbolFunction(opHandle.value, opName, returnHeader + returnType) - }).filterNot(_.name.startsWith("_")).groupBy(_.name.toLowerCase).map(ele => { - // Pattern matching for not generating depreciated method - if (ele._2.length == 1) ele._2.head - else { - if (ele._2.head.name.head.isLower) ele._2.head - else ele._2.last - } - }).toList - } - - // Create an atomic symbol function by handle and function name. - private def makeAtomicSymbolFunction(handle: SymbolHandle, - aliasName: String, returnType : String) - : absClassFunction = { - val name = new RefString - val desc = new RefString - val keyVarNumArgs = new RefString - val numArgs = new RefInt - val argNames = ListBuffer.empty[String] - val argTypes = ListBuffer.empty[String] - val argDescs = ListBuffer.empty[String] + val finalStr = + s"""/* + |* Licensed to the Apache Software Foundation (ASF) under one or more + |* contributor license agreements. See the NOTICE file distributed with + |* this work for additional information regarding copyright ownership. + |* The ASF licenses this file to You under the Apache License, Version 2.0 + |* (the "License"); you may not use this file except in compliance with + |* the License. You may obtain a copy of the License at + |* + |* http://www.apache.org/licenses/LICENSE-2.0 + |* + |* Unless required by applicable law or agreed to in writing, software + |* distributed under the License is distributed on an "AS IS" BASIS, + |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + |* See the License for the specific language governing permissions and + |* limitations under the License. + |*/ + | + |$packageDef + | + |import org.apache.mxnet.annotation.Experimental + | + |// scalastyle:off + |abstract class $className { + |${absFuncs.mkString("\n")} + |}""".stripMargin - _LIB.mxSymbolGetAtomicSymbolInfo( - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) - val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => - val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType) - new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2) - } - new absClassFunction(aliasName, desc.value, argList.toList, returnType) + val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala")) + pw.write(finalStr) + pw.close() + MD5Generator(finalStr) } -} + +} \ No newline at end of file diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala new file mode 100644 index 000000000000..9245ef1b437f --- /dev/null +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.init.Base.{RefInt, RefLong, RefString, _LIB} +import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} + +import scala.collection.mutable.ListBuffer +import scala.reflect.macros.blackbox + +abstract class GeneratorBase { + type Handle = Long + + case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) { + def safeArgName: String = argName match { + case "var" => "vari" + case "type" => "typeOf" + case _ => argName + } + } + + case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String) + + def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean, + isJava: Boolean = false): List[Func] = { + val l = getBackEndFunctions(isSymbol, isJava) + if (isContrib) { + l.filter(func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) + } else { + l.filterNot(_.name.startsWith("_")) + } + } + + def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = { + // Operators that should not be generated + val notGenerated = Set("Custom") + + val l = getBackEndFunctions(isSymbol) + val res = if (isContrib) { + l.filter(func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) + } else { + l.filterNot(_.name.startsWith("_")) + } + res.filterNot(ele => notGenerated.contains(ele.name)) + } + + protected def getBackEndFunctions(isSymbol: Boolean, isJava: Boolean = false): List[Func] = { + val opNames = ListBuffer.empty[String] + _LIB.mxListAllOpNames(opNames) + opNames.map(opName => { + val opHandle = new RefLong + _LIB.nnGetOpHandle(opName, opHandle) + makeAtomicFunction(opHandle.value, opName, isSymbol, isJava) + }).toList + } + + private def makeAtomicFunction(handle: Handle, aliasName: String, + isSymbol: Boolean, isJava: Boolean): Func = { + val name = new RefString + val desc = new RefString + val keyVarNumArgs = new RefString + val numArgs = new RefInt + val argNames = ListBuffer.empty[String] + val argTypes = ListBuffer.empty[String] + val argDescs = ListBuffer.empty[String] + + _LIB.mxSymbolGetAtomicSymbolInfo( + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) + val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs) + val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) { + s"This function support variable length of positional input (${keyVarNumArgs.value})." + } else { + "" + } + val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" + val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n" + + val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => + val family = if (isJava) "org.apache.mxnet.javaapi.NDArray" + else if (isSymbol) "org.apache.mxnet.Symbol" + else "org.apache.mxnet.NDArray" + val typeAndOption = + CToScalaUtils.argumentCleaner(argName, argType, family) + Arg(argName, typeAndOption._1, argDesc, typeAndOption._2) + } + val returnType = + if (isJava) "Array[org.apache.mxnet.javaapi.NDArray]" + else if (isSymbol) "org.apache.mxnet.Symbol" + else "org.apache.mxnet.NDArrayFuncReturn" + Func(aliasName, desc.value, argList.toList, returnType) + } + + /** + * Generate class structure for all function APIs + * + * @param c + * @param funcDef DefDef type of function definitions + * @param annottees + * @return + */ + protected def structGeneration(c: blackbox.Context) + (funcDef: List[c.universe.DefDef], annottees: c.Expr[Any]*) + : c.Expr[Any] = { + import c.universe._ + val inputs = annottees.map(_.tree).toList + // pattern match on the inputs + val modDefs = inputs map { + case ClassDef(mods, name, something, template) => + val q = template match { + case Template(superMaybe, emptyValDef, defs) => + Template(superMaybe, emptyValDef, defs ++ funcDef) + case ex => + throw new IllegalArgumentException(s"Invalid template: $ex") + } + ClassDef(mods, name, something, q) + case ModuleDef(mods, name, template) => + val q = template match { + case Template(superMaybe, emptyValDef, defs) => + Template(superMaybe, emptyValDef, defs ++ funcDef) + case ex => + throw new IllegalArgumentException(s"Invalid template: $ex") + } + ModuleDef(mods, name, q) + case ex => + throw new IllegalArgumentException(s"Invalid macro input: $ex") + } + // wrap the result up in an Expr, and return it + val result = c.Expr(Block(modDefs, Literal(Constant()))) + result + } + + protected def typedFunctionCommonArgDef(func: Func): List[String] = { + // build function argument definition, with optionality, and safe names + func.listOfArgs.map(arg => + if (arg.isOptional) { + // let's avoid a stupid Option[Array[...]] + if (arg.argType.startsWith("Array[")) { + s"${arg.safeArgName} : ${arg.argType} = Array.empty" + } else { + s"${arg.safeArgName} : Option[${arg.argType}] = None" + } + } + else { + s"${arg.safeArgName} : ${arg.argType}" + } + ) + } +} diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 2d3a1c7ec5af..d85abe1ecc4f 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -17,11 +17,8 @@ package org.apache.mxnet -import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} - import scala.annotation.StaticAnnotation -import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox @@ -30,207 +27,111 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot } private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.typeSafeAPIDefs + private[mxnet] def macroTransform(annottees: Any*) = macro TypedNDArrayAPIMacro.typeSafeAPIDefs } -private[mxnet] object NDArrayMacro { - case class NDArrayArg(argName: String, argType: String, isOptional : Boolean) - case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg]) - - // scalastyle:off havetype - def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(annottees: _*) - } - def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - typeSafeAPIImpl(c)(annottees: _*) - } - // scalastyle:off havetype - - private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule() +private[mxnet] object NDArrayMacro extends GeneratorBase { - private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { + def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ - val isContrib: Boolean = c.prefix.tree match { case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b)) } - val newNDArrayFunctions = { - if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_")) - else ndarrayFunctions.filterNot(_.name.startsWith("_")) - } - - val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction => - val funcName = NDArrayfunction.name - val termName = TermName(funcName) - Seq( - // scalastyle:off - // (yizhi) We are investigating a way to make these functions type-safe - // and waiting to see the new approach is stable enough. - // Thus these functions may be deprecated in the future. - // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*) - q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef], - // e.g def transpose(args: Any*) - q"def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef] - // scalastyle:on - ) - } - - structGeneration(c)(functionDefs, annottees : _*) + impl(c)(isContrib, annottees: _*) } - private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { + private def impl(c: blackbox.Context) + (isContrib: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ - val isContrib: Boolean = c.prefix.tree match { - case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) - } - // Defines Operators that should not generated - val notGenerated = Set("Custom") - - val newNDArrayFunctions = { - if (isContrib) ndarrayFunctions.filter( - func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) - else ndarrayFunctions.filterNot(_.name.startsWith("_")) - }.filterNot(ele => notGenerated.contains(ele.name)) - - val functionDefs = newNDArrayFunctions.map { ndarrayfunction => - - // Construct argument field - var argDef = ListBuffer[String]() - // Construct Implementation field - var impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]" - ndarrayfunction.listOfArgs.foreach({ ndarrayarg => - // var is a special word used to define variable in Scala, - // need to changed to something else in order to make it work - val currArgName = ndarrayarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => ndarrayarg.argName - } - if (ndarrayarg.isOptional) { - argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${ndarrayarg.argType}" - } - // NDArray arg implementation - val returnType = "org.apache.mxnet.NDArray" - - // TODO: Currently we do not add place holder for NDArray - // Example: an NDArray operator like the following format - // nd.foo(arg1: NDArray(required), arg2: NDArray(Optional), arg3: NDArray(Optional) - // If we place nd.foo(arg1, arg3 = arg3), do we need to add place holder for arg2? - // What it should be? - val base = - if (ndarrayarg.argType.equals(returnType)) { - s"args += $currArgName" - } else if (ndarrayarg.argType.equals(s"Array[$returnType]")){ - s"args ++= $currArgName" - } else { - "map(\"" + ndarrayarg.argName + "\") = " + currArgName - } - impl.append( - if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get" - else base - ) - }) - // add default out parameter - argDef += "out : Option[NDArray] = None" - impl += "if (!out.isEmpty) map(\"out\") = out.get" - // scalastyle:off - impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)" - // scalastyle:on - // Combine and build the function string - val returnType = "org.apache.mxnet.NDArrayFuncReturn" - var finalStr = s"def ${ndarrayfunction.name}" - finalStr += s" (${argDef.mkString(",")}) : $returnType" - finalStr += s" = {${impl.mkString("\n")}}" - c.parse(finalStr).asInstanceOf[DefDef] + val functions = functionsToGenerate(isSymbol = false, isContrib) + + val functionDefs = functions.flatMap { NDArrayfunction => + val funcName = NDArrayfunction.name + val termName = TermName(funcName) + Seq( + // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*) + q""" + def $termName(kwargs: Map[String, Any] = null)(args: Any*) = { + genericNDArrayFunctionInvoke($funcName, args, kwargs) + } + """.asInstanceOf[DefDef], + // e.g def transpose(args: Any*) + q""" + def $termName(args: Any*) = { + genericNDArrayFunctionInvoke($funcName, args, null) + } + """.asInstanceOf[DefDef] + ) } - structGeneration(c)(functionDefs, annottees : _*) + structGeneration(c)(functionDefs, annottees: _*) } +} - private def structGeneration(c: blackbox.Context) - (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*) - : c.Expr[Any] = { +private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { + + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ - val inputs = annottees.map(_.tree).toList - // pattern match on the inputs - val modDefs = inputs map { - case ClassDef(mods, name, something, template) => - val q = template match { - case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef) - case ex => - throw new IllegalArgumentException(s"Invalid template: $ex") - } - ClassDef(mods, name, something, q) - case ModuleDef(mods, name, template) => - val q = template match { - case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef) - case ex => - throw new IllegalArgumentException(s"Invalid template: $ex") - } - ModuleDef(mods, name, q) - case ex => - throw new IllegalArgumentException(s"Invalid macro input: $ex") + val isContrib: Boolean = c.prefix.tree match { + case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) } - // wrap the result up in an Expr, and return it - val result = c.Expr(Block(modDefs, Literal(Constant()))) - result + + val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib) + + val functionDefs = functions.map(f => buildTypedFunction(c)(f)) + structGeneration(c)(functionDefs, annottees: _*) } + protected def buildTypedFunction(c: blackbox.Context) + (function: Func): c.universe.DefDef = { + import c.universe._ + val returnType = "org.apache.mxnet.NDArrayFuncReturn" + val ndarrayType = "org.apache.mxnet.NDArray" + // Construct argument field + val argDef = ListBuffer[String]() + argDef ++= typedFunctionCommonArgDef(function) + argDef += "out : Option[NDArray] = None" - // List and add all the atomic symbol functions to current module. - private def initNDArrayModule(): List[NDArrayFunction] = { - val opNames = ListBuffer.empty[String] - _LIB.mxListAllOpNames(opNames) - opNames.map(opName => { - val opHandle = new RefLong - _LIB.nnGetOpHandle(opName, opHandle) - makeNDArrayFunction(opHandle.value, opName) - }).toList - } + // Construct Implementation field + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + impl += s"val args = scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]" - // Create an atomic symbol function by handle and function name. - private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String) - : NDArrayFunction = { - val name = new RefString - val desc = new RefString - val keyVarNumArgs = new RefString - val numArgs = new RefInt - val argNames = ListBuffer.empty[String] - val argTypes = ListBuffer.empty[String] - val argDescs = ListBuffer.empty[String] - - _LIB.mxSymbolGetAtomicSymbolInfo( - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) - val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs) - val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) { - s"This function support variable length of positional input (${keyVarNumArgs.value})." - } else { - "" - } - val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" - val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n" - // scalastyle:off println - if (System.getenv("MXNET4J_PRINT_OP_DEF") != null - && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") { - println("NDArray function definition:\n" + docStr) - } - // scalastyle:on println - val argList = argNames zip argTypes map { case (argName, argType) => - val typeAndOption = - CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.NDArray") - new NDArrayArg(argName, typeAndOption._1, typeAndOption._2) + // NDArray arg implementation + impl ++= function.listOfArgs.map { arg => + if (arg.argType.equals(s"Array[$ndarrayType]")) { + s"args ++= ${arg.safeArgName}" + } else { + val base = + if (arg.argType.equals(ndarrayType)) { + // ndarrays go to args + s"args += ${arg.safeArgName}" + } else { + // other types go to kwargs + s"""map("${arg.argName}") = ${arg.safeArgName}""" + } + if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get" + else base + } } - new NDArrayFunction(aliasName, argList.toList) + + impl += + s"""if (!out.isEmpty) map("out") = out.get + |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke( + | "${function.name}", args.toSeq, map.toMap) + """.stripMargin + + // Combine and build the function string + val finalStr = + s"""def ${function.name} + | (${argDef.mkString(",")}) : $returnType + | = {${impl.mkString("\n")}} + """.stripMargin + + c.parse(finalStr).asInstanceOf[DefDef] } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 42aa11781d8f..ab864e1ef195 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -21,222 +21,106 @@ import scala.annotation.StaticAnnotation import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox -import org.apache.mxnet.init.Base._ -import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs + private[mxnet] def macroTransform(annottees: Any*) = macro SymbolMacro.addDefs } private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typeSafeAPIDefs + private[mxnet] def macroTransform(annottees: Any*) = macro TypedSymbolAPIMacro.typeSafeAPIDefs } -private[mxnet] object SymbolImplMacros { - case class SymbolArg(argName: String, argType: String, isOptional : Boolean) - case class SymbolFunction(name: String, listOfArgs: List[SymbolArg]) +private[mxnet] object SymbolMacro extends GeneratorBase { - // scalastyle:off havetype - def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(annottees: _*) - } - def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - typedAPIImpl(c)(annottees: _*) - } - // scalastyle:on havetype - - private val symbolFunctions: List[SymbolFunction] = initSymbolModule() - - /** - * Implementation for fixed input API structure - */ - private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { + def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ - val isContrib: Boolean = c.prefix.tree match { case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b)) } - val newSymbolFunctions = { - if (isContrib) symbolFunctions.filter( - func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) - else symbolFunctions.filter(!_.name.startsWith("_")) - } + impl(c)(isContrib, annottees: _*) + } + + private def impl(c: blackbox.Context) + (isContrib: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { + import c.universe._ + val functions = functionsToGenerate(isSymbol = false, isContrib) - val functionDefs = newSymbolFunctions map { symbolfunction => - val funcName = symbolfunction.name - val tName = TermName(funcName) - q""" + val functionDefs = functions.map { symbolfunction => + val funcName = symbolfunction.name + val tName = TermName(funcName) + q""" def $tName(name : String = null, attr : Map[String, String] = null) - (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null) - : org.apache.mxnet.Symbol = { - createSymbolGeneral($funcName,name,attr,args,kwargs) - } + (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null) + : org.apache.mxnet.Symbol = { + createSymbolGeneral($funcName,name,attr,args,kwargs) + } """.asInstanceOf[DefDef] - } + } - structGeneration(c)(functionDefs, annottees : _*) + structGeneration(c)(functionDefs, annottees: _*) } +} - /** - * Implementation for Dynamic typed API Symbol.api. - */ - private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { - import c.universe._ +private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { + import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b)) } - // Defines Operators that should not generated - val notGenerated = Set("Custom") - - // TODO: Put Symbol.api.foo --> Stable APIs - // Symbol.contrib.bar--> Contrib APIs - val newSymbolFunctions = { - if (isContrib) symbolFunctions.filter( - func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) - else symbolFunctions.filter(!_.name.startsWith("_")) - }.filterNot(ele => notGenerated.contains(ele.name)) - - val functionDefs = newSymbolFunctions map { symbolfunction => - - // Construct argument field - var argDef = ListBuffer[String]() - // Construct Implementation field - var impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - impl += "var args = Seq[org.apache.mxnet.Symbol]()" - symbolfunction.listOfArgs.foreach({ symbolarg => - // var is a special word used to define variable in Scala, - // need to changed to something else in order to make it work - val currArgName = symbolarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => symbolarg.argName - } - if (symbolarg.isOptional) { - argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${symbolarg.argType}" - } - // Symbol arg implementation - val returnType = "org.apache.mxnet.Symbol" - val base = - if (symbolarg.argType.equals(s"Array[$returnType]")) { - if (symbolarg.isOptional) s"if (!$currArgName.isEmpty) args = $currArgName.get.toSeq" - else s"args = $currArgName.toSeq" - } else { - if (symbolarg.isOptional) { - // scalastyle:off - s"if (!$currArgName.isEmpty) map(" + "\"" + symbolarg.argName + "\"" + s") = $currArgName.get" - // scalastyle:on - } - else "map(\"" + symbolarg.argName + "\"" + s") = $currArgName" - } + val functions = typeSafeFunctionsToGenerate(isSymbol = true, isContrib) - impl += base - }) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" - // scalastyle:off - // TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated - impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, args, map.toMap)" - // scalastyle:on - // Combine and build the function string - val returnType = "org.apache.mxnet.Symbol" - var finalStr = s"def ${symbolfunction.name}" - finalStr += s" (${argDef.mkString(",")}) : $returnType" - finalStr += s" = {${impl.mkString("\n")}}" - c.parse(finalStr).asInstanceOf[DefDef] - } - structGeneration(c)(functionDefs, annottees : _*) + val functionDefs = functions.map(f => buildTypedFunction(c)(f)) + structGeneration(c)(functionDefs, annottees: _*) } - /** - * Generate class structure for all function APIs - * @param c - * @param funcDef DefDef type of function definitions - * @param annottees - * @return - */ - private def structGeneration(c: blackbox.Context) - (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*) - : c.Expr[Any] = { + protected def buildTypedFunction(c: blackbox.Context) + (function: Func): c.universe.DefDef = { import c.universe._ - val inputs = annottees.map(_.tree).toList - // pattern match on the inputs - val modDefs = inputs map { - case ClassDef(mods, name, something, template) => - val q = template match { - case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef) - case ex => - throw new IllegalArgumentException(s"Invalid template: $ex") - } - ClassDef(mods, name, something, q) - case ModuleDef(mods, name, template) => - val q = template match { - case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef) - case ex => - throw new IllegalArgumentException(s"Invalid template: $ex") - } - ModuleDef(mods, name, q) - case ex => - throw new IllegalArgumentException(s"Invalid macro input: $ex") - } - // wrap the result up in an Expr, and return it - val result = c.Expr(Block(modDefs, Literal(Constant()))) - result - } - // List and add all the atomic symbol functions to current module. - private def initSymbolModule(): List[SymbolFunction] = { - val opNames = ListBuffer.empty[String] - _LIB.mxListAllOpNames(opNames) - // TODO: Add '_linalg_', '_sparse_', '_image_' support - opNames.map(opName => { - val opHandle = new RefLong - _LIB.nnGetOpHandle(opName, opHandle) - makeAtomicSymbolFunction(opHandle.value, opName) - }).toList - } + val returnType = "org.apache.mxnet.Symbol" + val symbolType = "org.apache.mxnet.Symbol" - // Create an atomic symbol function by handle and function name. - private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String) - : SymbolFunction = { - val name = new RefString - val desc = new RefString - val keyVarNumArgs = new RefString - val numArgs = new RefInt - val argNames = ListBuffer.empty[String] - val argTypes = ListBuffer.empty[String] - val argDescs = ListBuffer.empty[String] - - _LIB.mxSymbolGetAtomicSymbolInfo( - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) - val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs) - val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) { - s"This function support variable length of positional input (${keyVarNumArgs.value})." + // Construct argument field + val argDef = ListBuffer[String]() + argDef ++= typedFunctionCommonArgDef(function) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + + // Construct Implementation field + val impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + impl += s"var args = scala.collection.Seq[$symbolType]()" + + // Symbol arg implementation + impl ++= function.listOfArgs.map { arg => + if (arg.argType.equals(s"Array[$symbolType]")) { + s"if (!${arg.safeArgName}.isEmpty) args = ${arg.safeArgName}.toSeq" } else { - "" + // all go in kwargs + if (arg.isOptional) { + s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = ${arg.safeArgName}.get""" + } else { + s"""map("${arg.argName}") = ${arg.safeArgName}""" + } } - val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" - val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n" - // scalastyle:off println - if (System.getenv("MXNET4J_PRINT_OP_DEF") != null - && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") { - println("Symbol function definition:\n" + docStr) } - // scalastyle:on println - val argList = argNames zip argTypes map { case (argName, argType) => - val typeAndOption = - CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.Symbol") - new SymbolArg(argName, typeAndOption._1, typeAndOption._2) - } - new SymbolFunction(aliasName, argList.toList) + + impl += + s"""org.apache.mxnet.Symbol.createSymbolGeneral( + | "${function.name}", name, attr, args, map.toMap) + """.stripMargin + + // Combine and build the function string + val finalStr = + s"""def ${function.name} + | (${argDef.mkString(",")}) : $returnType + | = {${impl.mkString("\n")}} + """.stripMargin + + c.parse(finalStr).asInstanceOf[DefDef] } }