diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala index 52e26efb41f1..b54ecc05818e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala @@ -38,6 +38,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param to_rgb Whether to convert decoded image * to mxnet's default RGB format (instead of opencv's default BGR). + * @param out NDArray to store the output * @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(buf: Array[Byte], flag: Int, @@ -57,6 +58,10 @@ object Image { /** * Same imageDecode with InputStream * @param inputStream the inputStream of the image + * @param flag Convert decoded image to grayscale (0) or color (1). + * @param to_rgb Whether to convert decoded image + * to mxnet's default RGB format (instead of opencv's default BGR). + * @param out NDArray to store the output * @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(inputStream: InputStream, flag: Int = 1, @@ -79,6 +84,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param to_rgb Whether to convert decoded image to mxnet's default RGB format * (instead of opencv's default BGR). + * @param out NDArray to store the output * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]] */ def imRead(filename: String, flag: Option[Int] = None, @@ -99,6 +105,7 @@ object Image { * @param w Width of resized image. * @param h Height of resized image. * @param interp Interpolation method (default=cv2.INTER_LINEAR). + * @param out NDArray to store the output * @return org.apache.mxnet.NDArray */ def imResize(src: org.apache.mxnet.NDArray, w: Int, h: Int, @@ -124,6 +131,7 @@ object Image { * @param typeOf Filling type (default=cv2.BORDER_CONSTANT). * @param value (Deprecated! Use ``values`` instead.) Fill with single value. * @param values Fill with value(RGB[A] or gray), up to 4 channels. + * @param out NDArray to store the output * @return org.apache.mxnet.NDArray */ def copyMakeBorder(src: org.apache.mxnet.NDArray, top: Int, bot: Int, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 849f4566f528..3764f5a4a040 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -36,6 +36,11 @@ import scala.util.Try */ @AddNDArrayFunctions(false) object NDArray extends NDArrayBase { + /** + * method to convert NDArrayFunctionReturn to NDArray + * @param ret the returned NDArray list + * @return NDArray result + */ implicit def getFirstResult(ret: NDArrayFuncReturn): NDArray = ret(0) private val logger = LoggerFactory.getLogger(classOf[NDArray]) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala index f72223d1e4da..57a485083f20 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala @@ -36,6 +36,12 @@ object Image { org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None) } + /** + * Decode image with OpenCV. + * Note: return image in RGB by default, instead of OpenCV's default BGR. + * @param buf Buffer containing binary encoded image + * @return NDArray in HWC format with DType [[DType.UInt8]] + */ def imDecode(buf: Array[Byte]): NDArray = { imDecode(buf, 1, true) } @@ -52,6 +58,12 @@ object Image { org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None) } + /** + * Same imageDecode with InputStream + * + * @param inputStream the inputStream of the image + * @return NDArray in HWC format with DType [[DType.UInt8]] + */ def imDecode(inputStream: InputStream): NDArray = { imDecode(inputStream, 1, true) } @@ -69,6 +81,12 @@ object Image { org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None) } + /** + * Read and decode image with OpenCV. + * Note: return image in RGB by default, instead of OpenCV's default BGR. + * @param filename Name of the image file to be loaded. + * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]] + */ def imRead(filename: String): NDArray = { imRead(filename, 1, true) } @@ -86,6 +104,13 @@ object Image { org.apache.mxnet.Image.imResize(src, w, h, interpVal, None) } + /** + * Resize image with OpenCV. + * @param src source image in NDArray + * @param w Width of resized image. + * @param h Height of resized image. + * @return org.apache.mxnet.NDArray + */ def imResize(src: NDArray, w: Int, h: Int): NDArray = { imResize(src, w, h, null) } diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala index b5a6286af1b6..e3fa28fb2a59 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala @@ -29,17 +29,41 @@ object Base { class RefFloat(val value: Float = 0) class RefString(val value: String = null) + /** + * This C Pointer Address point to the + * actual memory piece created in MXNet Engine + */ type CPtrAddress = Long + /** + * NDArrayHandle is the C pointer to + * the NDArray + */ type NDArrayHandle = CPtrAddress + /** + * FunctionHandle is the C pointer to + * the ids of the operators + */ type FunctionHandle = CPtrAddress + /** + * KVStoreHandle is the C pointer to + * the KVStore + */ type KVStoreHandle = CPtrAddress + /** + * ExecutorHandle is the C pointer to + * the Executor + */ type ExecutorHandle = CPtrAddress + /** + * SymbolHandle is the C pointer to + * the Symbol + */ type SymbolHandle = CPtrAddress @throws(classOf[UnsatisfiedLinkError]) private def tryLoadInitLibrary(): Unit = { - var userDir : File = new File(System.getProperty("user.dir")) + val userDir : File = new File(System.getProperty("user.dir")) var nativeDir : File = new File(userDir, "init-native") if (!nativeDir.exists()) { nativeDir = new File(userDir.getParent, "init-native") @@ -50,7 +74,6 @@ object Base { val baseDir = nativeDir.getAbsolutePath val os = System.getProperty("os.name") - // ref: http://lopica.sourceforge.net/os.html if (os.startsWith("Linux")) { System.load(s"$baseDir/target/libmxnet-init-scala.so") } else if (os.startsWith("Mac")) { diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/LibInfo.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/LibInfo.scala index 7bd0c701f872..c813d449f652 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/LibInfo.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/LibInfo.scala @@ -22,7 +22,25 @@ import org.apache.mxnet.init.Base._ import scala.collection.mutable.ListBuffer class LibInfo { + /** + * Get the list of the symbol ids + * @param symbolList Pass in an empty ListBuffer and obtain a list of operator IDs + * @return Callback result + */ @native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int + + /** + * Get the detailed information of an operator + * @param handle The ID of the operator + * @param name Name of the operator + * @param desc Description of the operator + * @param numArgs Number of arguments + * @param argNames Argument names + * @param argTypes Argument types + * @param argDescs Argument descriptions + * @param keyVarNumArgs Kwargs number + * @return Callback result + */ @native def mxSymbolGetAtomicSymbolInfo(handle: SymbolHandle, name: RefString, desc: RefString, @@ -31,6 +49,18 @@ class LibInfo { argTypes: ListBuffer[String], argDescs: ListBuffer[String], keyVarNumArgs: RefString): Int + /** + * Get the name list of all operators + * @param names Names of all operator + * @return Callback result + */ @native def mxListAllOpNames(names: ListBuffer[String]): Int + + /** + * Get operator ID from its name + * @param opName Operator name + * @param opHandle Operator ID + * @return Callback result + */ @native def nnGetOpHandle(opName: String, opHandle: RefLong): Int } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index ede16f73d2a1..a5102d6624ef 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -23,12 +23,16 @@ import java.security.MessageDigest import scala.collection.mutable.ListBuffer /** - * This object will generate the Scala documentation of the new Scala API - * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala + * This object will generate the Scala documentation of the Scala/Java APIs * The code will be executed during Macros stage and file live in Core stage */ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { + /** + * Main method used to generate code and write to files + * A hash check placed at the end to verify changes + * @param args Input args + */ def main(args: Array[String]): Unit = { val FILE_PATH = args(0) val hashCollector = ListBuffer[String]() @@ -42,6 +46,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { val finalHash = hashCollector.mkString("\n") } + /** + * Generate MD5 result from an input string + * Encoded in UTF-8 + * @param input The input string + * @return A MD5 value from the string + */ def MD5Generator(input: String): String = { val md = MessageDigest.getInstance("MD5") md.update(input.getBytes("UTF-8")) @@ -49,6 +59,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest) } + /** + * Type-safe class body generation for NDArray/Symbol + * @param FILE_PATH File path write the file to + * @param isSymbol Check if write the Symbol API, NDArray otherwise + * @return MD5 String + */ def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false) .map { func => @@ -65,6 +81,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { generated) } + /** + * Generate the Random classes for Symbol/NDArray + * @param FILE_PATH File path write the file to + * @param isSymbol Check if write the Symbol API, NDArray otherwise + * @return MD5 String + */ def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = { val generated = typeSafeRandomFunctionsToGenerate(isSymbol) .map { func => @@ -83,6 +105,16 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { generated) } + /** + * Non Type-safe interface of Scala Symbol/NDArray + * It includes class definition : e.g class SymbolBase + * and function definitions : e.g def softmax(...)(...)(...) : NDArray + * Users can directly use the api by calling NDArray. + * It support both positional input or Map input + * @param FILE_PATH File path write the file to + * @param isSymbol Check if write the Symbol API, NDArray otherwise + * @return MD5 String + */ def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { val absFuncs = functionsToGenerate(isSymbol, isContrib = false) .map { func => @@ -112,7 +144,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { absFuncs) } - def javaClassGen(filePath : String) : String = { + /** + * Type-safe interface of Java NDArray + * @param FILE_PATH File path write the file to + * @return MD5 String + */ + def javaClassGen(FILE_PATH : String) : String = { val notGenerated = Set("Custom") val absClassFunctions = functionsToGenerate(false, false, true) val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name)) @@ -133,13 +170,19 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { val packageName = "NDArrayBase" val packageDef = "package org.apache.mxnet.javaapi" writeFile( - filePath + "javaapi/", + FILE_PATH + "javaapi/", packageDef, packageName, "import org.apache.mxnet.annotation.Experimental", absFuncs) } + /** + * Generate Scala docs from the function description + * @param func The function case class + * @param withParam Whether to generate param field + * @return A formatted string for the function description + */ def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = { def fixDesc(desc: String): String = { var curDesc = desc @@ -173,6 +216,14 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { } } + /** + * Generate the function interface + * e.g: def softmax(data: NDArray, name ...): NDArrayFunctionReturn + * @param func The function case class + * @param isSymbol Check if generate Symbol function, NDArray otherwise + * @param typeParameter Type param specifically used in Random Module + * @return Formatted string for the function + */ def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = { val argDef = ListBuffer[String]() @@ -192,6 +243,11 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { |def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin } + /** + * Generate Java function interface + * @param func The function case class + * @return A formatted string for the function + */ def generateJavaAPISignature(func : Func) : String = { val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2 var argDef = ListBuffer[String]() @@ -250,6 +306,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { } } + /** + * Write the formatted string to file + * @param FILE_PATH Location of the file writes to + * @param packageDef Package definition + * @param className Class name + * @param imports Packages need to import + * @param absFuncs All formatted functions + * @return A MD5 string + */ def writeFile(FILE_PATH: String, packageDef: String, className: String, imports: String, absFuncs: Seq[String]): String = { diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala index b2033f529c65..e3fdeabd54e9 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -17,16 +17,20 @@ package org.apache.mxnet -import org.apache.mxnet.init.Base.{RefInt, RefLong, RefString, _LIB} -import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils} +import org.apache.mxnet.init.Base.{CPtrAddress, RefInt, RefLong, RefString, _LIB} +import org.apache.mxnet.utils.CToScalaUtils import scala.collection.mutable.ListBuffer import scala.reflect.macros.blackbox private[mxnet] abstract class GeneratorBase { - type Handle = Long case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) { + /** + * Filter the arg name with the Scala keyword that are not allow to use as arg name, + * such as var and type listed in here. This is due to the diff between C and Scala + * @return argname that works in Scala + */ def safeArgName: String = argName match { case "var" => "vari" case "type" => "typeOf" @@ -36,6 +40,14 @@ private[mxnet] abstract class GeneratorBase { case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String) + /** + * Non Type-safe function generation method + * This method will filter all "_" functions + * @param isSymbol Check if generate the Symbol method + * @param isContrib Check if generate the contrib method + * @param isJava Check if generate Corresponding Java method + * @return List of functions + */ def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean, isJava: Boolean = false): List[Func] = { val l = getBackEndFunctions(isSymbol, isJava) @@ -46,7 +58,12 @@ private[mxnet] abstract class GeneratorBase { } } - // filter the operators to generate in the type-safe Symbol.api and NDArray.api + /** + * Filter the operators to generate in the type-safe Symbol.api and NDArray.api + * @param isSymbol Check if generate the Symbol method + * @param isContrib Check if generate the contrib method + * @return List of functions + */ protected def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = { // Operators that should not be generated val notGenerated = Set("Custom") @@ -60,6 +77,12 @@ private[mxnet] abstract class GeneratorBase { res.filterNot(ele => notGenerated.contains(ele.name)) } + /** + * Extract and format the functions obtained from C API + * @param isSymbol Check if generate for Symbol + * @param isJava Check if extracting in Java format + * @return List of functions + */ protected def getBackEndFunctions(isSymbol: Boolean, isJava: Boolean = false): List[Func] = { val opNames = ListBuffer.empty[String] _LIB.mxListAllOpNames(opNames) @@ -70,7 +93,7 @@ private[mxnet] abstract class GeneratorBase { }).toList } - private def makeAtomicFunction(handle: Handle, aliasName: String, + private def makeAtomicFunction(handle: CPtrAddress, aliasName: String, isSymbol: Boolean, isJava: Boolean): Func = { val name = new RefString val desc = new RefString @@ -82,14 +105,11 @@ private[mxnet] abstract class GeneratorBase { _LIB.mxSymbolGetAtomicSymbolInfo( handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) - val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs) val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) { s"This function support variable length of positional input (${keyVarNumArgs.value})." } else { "" } - val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})" - val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n" val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => val family = if (isJava) "org.apache.mxnet.javaapi.NDArray" @@ -109,10 +129,10 @@ private[mxnet] abstract class GeneratorBase { /** * Generate class structure for all function APIs * - * @param c + * @param c Context used for generation * @param funcDef DefDef type of function definitions - * @param annottees - * @return + * @param annottees Annottees used to define Class or Module + * @return Expr used for code generation */ protected def structGeneration(c: blackbox.Context) (funcDef: List[c.universe.DefDef], annottees: c.Expr[Any]*) @@ -145,7 +165,11 @@ private[mxnet] abstract class GeneratorBase { result } - // build function argument definition, with optionality, and safe names + /** + * Build function argument definition, with optionality, and safe names + * @param func Functions + * @return List of string representing the functions interface + */ protected def typedFunctionCommonArgDef(func: Func): List[String] = { func.listOfArgs.map(arg => if (arg.isOptional) { @@ -167,14 +191,23 @@ private[mxnet] abstract class GeneratorBase { private[mxnet] trait RandomHelpers { self: GeneratorBase => - // a generic type spec used in Symbol.random and NDArray.random modules +/** + * A generic type spec used in Symbol.random and NDArray.random modules + * @param isSymbol Check if generate for Symbol + * @param fullPackageSpec Check if leave the full name of the classTag + * @return A formatted string for random Symbol/NDArray + */ protected def randomGenericTypeSpec(isSymbol: Boolean, fullPackageSpec: Boolean): String = { val classTag = if (fullPackageSpec) "scala.reflect.ClassTag" else "ClassTag" if (isSymbol) s"[T: SymbolOrScalar : $classTag]" else s"[T: NDArrayOrScalar : $classTag]" } - // filter the operators to generate in the type-safe Symbol.random and NDArray.random +/** + * Filter the operators to generate in the type-safe Symbol.random and NDArray.random + * @param isSymbol Check if generate Symbol functions + * @return List of functions + */ protected def typeSafeRandomFunctionsToGenerate(isSymbol: Boolean): List[Func] = { getBackEndFunctions(isSymbol) .filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_")) @@ -206,16 +239,24 @@ private[mxnet] trait RandomHelpers { ) } - // hacks to manage the fact that random_normal and sample_normal have - // non-consistent parameter naming in the back-end - // this first one, merge loc/scale and mu/sigma + /** + * Hacks to manage the fact that random_normal and sample_normal have + * non-consistent parameter naming in the back-end + * this first one, merge loc/scale and mu/sigma + * @param arg Argument need to modify + * @return Arg case class with clean arg names + */ protected def hackNormalFunc(arg: Arg): Arg = { if (arg.argName == "loc") arg.copy(argName = "mu") else if (arg.argName == "scale") arg.copy(argName = "sigma") else arg } - // this second one reverts this merge prior to back-end call + /** + * This second one reverts this merge prior to back-end call + * @param func Function case class + * @return A string contains the implementation of random args + */ protected def unhackNormalFunc(func: Func): String = { if (func.name.equals("normal")) { s"""if(target.equals("random_normal")) { diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index f5b8bce11cf5..c9c10f50c01f 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -22,16 +22,30 @@ import scala.language.experimental.macros import scala.reflect.macros.blackbox private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*): Any = macro NDArrayMacro.addDefs +/** + * Generate non-typesafe method for NDArray operations + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + private[mxnet] def macroTransform(annottees: Any*) = macro NDArrayMacro.addDefs } private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*): Any = - macro TypedNDArrayAPIMacro.typeSafeAPIDefs +/** + * Generate typesafe method for NDArray operations + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + private[mxnet] def macroTransform(annottees: Any*) = macro TypedNDArrayAPIMacro.typeSafeAPIDefs } private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*): Any = +/** + * Generate typesafe method for Random Symbol + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + private[mxnet] def macroTransform(annottees: Any*) = macro TypedNDArrayRandomAPIMacro.typeSafeAPIDefs } @@ -39,8 +53,13 @@ private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends StaticAnno * For non-typed NDArray API */ private[mxnet] object NDArrayMacro extends GeneratorBase { - - def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Nothing] = { + /** + * Methods that check the ``isContrib`` and call code generation + * @param c Context used for code gen + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b)) @@ -82,8 +101,13 @@ private[mxnet] object NDArrayMacro extends GeneratorBase { * NDArray.api code generation */ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { - - def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Nothing] = { + /** + * Methods that check the ``isContrib`` and call code generation + * @param c Context used for code gen + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b)) @@ -95,6 +119,12 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { structGeneration(c)(functionDefs, annottees: _*) } + /** + * Methods that construct the code and build the syntax tree + * @param c Context used for code gen + * @param function Case class that store all information of the single function + * @return Generated syntax tree + */ protected def buildTypedFunction(c: blackbox.Context) (function: Func): c.universe.DefDef = { import c.universe._ @@ -148,8 +178,13 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase { */ private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase with RandomHelpers { - - def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Nothing] = { + /** + * methods that check the ``isContrib`` and call code generation + * @param c Context used for code gen + * @param annottees annottees used to define Class or Module + * @return generated code for injection + */ + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { // Note: no contrib managed in this module val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = false) @@ -158,6 +193,12 @@ private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase structGeneration(c)(functionDefs, annottees: _*) } + /** + * Methods that construct the code and build the syntax tree + * @param c Context used for code gen + * @param function Case class that store all information of the single function + * @return Generated syntax tree + */ protected def buildTypedFunction(c: blackbox.Context) (function: Func): c.universe.DefDef = { import c.universe._ diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 06b567c3d2d4..1a0405cfd63d 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -23,16 +23,30 @@ import scala.language.experimental.macros import scala.reflect.macros.blackbox private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*): Any = macro SymbolMacro.addDefs +/** + * Generate non-typesafe method for Symbol operations + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + private[mxnet] def macroTransform(annottees: Any*) = macro SymbolMacro.addDefs } private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*): Any = - macro TypedSymbolAPIMacro.typeSafeAPIDefs +/** + * Generate typesafe method for Symbol + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + private[mxnet] def macroTransform(annottees: Any*) = macro TypedSymbolAPIMacro.typeSafeAPIDefs } private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*): Any = +/** + * Generate typesafe method for Random Symbol + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + private[mxnet] def macroTransform(annottees: Any*) = macro TypedSymbolRandomAPIMacro.typeSafeAPIDefs } @@ -41,7 +55,13 @@ private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends StaticAnnot */ private[mxnet] object SymbolMacro extends GeneratorBase { - def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Nothing] = { + /** + * Methods that check the ``isContrib`` and call code generation + * @param c Context used for code gen + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b)) @@ -77,7 +97,13 @@ private[mxnet] object SymbolMacro extends GeneratorBase { */ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { - def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Nothing] = { + /** + * Methods that check the ``isContrib`` and call code generation + * @param c Context used for code gen + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b)) @@ -89,6 +115,12 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { structGeneration(c)(functionDefs, annottees: _*) } + /** + * Methods that construct the code and build the syntax tree + * @param c Context used for code gen + * @param function Case class that store all information of the single function + * @return Generated syntax tree + */ protected def buildTypedFunction(c: blackbox.Context) (function: Func): c.universe.DefDef = { import c.universe._ @@ -141,13 +173,25 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase { private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase with RandomHelpers { - def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Nothing] = { + /** + * Methods that check the ``isContrib`` and call code generation + * @param c Context used for code gen + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = true) .map(f => buildTypedFunction(c)(f)) structGeneration(c)(functionDefs, annottees: _*) } + /** + * Methods that construct the code and build the syntax tree + * @param c Context used for code gen + * @param function Case class that store all information of the single function + * @return Generated syntax tree + */ protected def buildTypedFunction(c: blackbox.Context) (function: Func): c.universe.DefDef = { import c.universe._ diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala index 9bf0818c14a4..29206247296d 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -25,12 +25,23 @@ import scala.language.experimental.macros import scala.reflect.macros.blackbox private[mxnet] class AddJNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation { - private[mxnet] def macroTransform(annottees: Any*): Any = macro JavaNDArrayMacro.typeSafeAPIDefs +/** + * Generate typesafe method for Java NDArray operations + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + private[mxnet] def macroTransform(annottees: Any*) = macro JavaNDArrayMacro.typeSafeAPIDefs } private[mxnet] object JavaNDArrayMacro extends GeneratorBase { - def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Nothing] = { + /** + * Methods that call code generation + * @param c Context used for code gen + * @param annottees Annottees used to define Class or Module + * @return Generated code for injection + */ + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { typeSafeAPIImpl(c)(annottees: _*) } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index 12d797f9b100..c984d07143ef 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -31,7 +31,15 @@ private[mxnet] object CToScalaUtils { "double" -> "Double", "bool" -> "Boolean") - // Convert C++ Types to Scala Types + /** + * Convert C++ Types to Scala Types + * @param in Input raw string that contains C type docs + * @param argType Arg type that used for error messaging + * @param argName Arg name used for error messaging + * @param returnType The type that NDArray/Symbol should be + * @param isJava Check if generating for Java + * @return String that contains right Scala/Java types + */ def typeConversion(in : String, argType : String = "", argName : String, returnType : String, isJava : Boolean) : String = { val header = returnType.split("\\.").dropRight(1) @@ -64,6 +72,8 @@ private[mxnet] object CToScalaUtils { * optional, what is it Scala type and possibly pass in a default value * @param argName The name of the argument * @param argType Raw arguement Type description + * @param returnType Return type of the function (Symbol/NDArray) + * @param isJava Check if Java args should be generated * @return (Scala_Type, isOptional) */ def argumentCleaner(argName: String, argType : String, diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/OperatorBuildUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/OperatorBuildUtils.scala deleted file mode 100644 index 383c68c0fb10..000000000000 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/OperatorBuildUtils.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mxnet.utils - -private[mxnet] object OperatorBuildUtils { - // Convert ctypes returned doc string information into parameters docstring. - def ctypes2docstring(argNames: Seq[String], - argTypes: Seq[String], - argDescs: Seq[String]): String = { - val params = - (argNames zip argTypes zip argDescs) map { case ((argName, argType), argDesc) => - val desc = if (argDesc.isEmpty) "" else s"\n$argDesc" - s"$argName : $argType$desc" - } - s"Parameters\n----------\n${params.mkString("\n")}\n" - } -}