Skip to content

Commit

Permalink
[MXNET-918] Introduce Random module / Refact code generation (apache#…
Browse files Browse the repository at this point in the history
…13038)

* refactor code gen

* remove xxxAPIMacroBase (overkill)

* CI errors / scala-style

* PR review comments
  • Loading branch information
mdespriee authored and lanking520 committed Nov 13, 2018
1 parent 1bb5b7f commit 251d2cd
Show file tree
Hide file tree
Showing 5 changed files with 440 additions and 549 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public void runBatchInference() {
List<NDArray> nd = new ArrayList<>();
NDArray[] temp = new NDArray[batchSize];
for (int i = 0; i < batchSize; i++) temp[i] = img.copy();
NDArray batched = NDArray.concat(temp, batchSize).setdim(0).invoke().get();
NDArray batched = NDArray.concat(temp, batchSize, 0, null)[0];
nd.add(batched);
objDet.objectDetectWithNDArray(nd, 3);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,196 +17,147 @@

package org.apache.mxnet

import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.CToScalaUtils
import java.io._
import java.security.MessageDigest

import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.collection.mutable.ListBuffer

/**
* This object will generate the Scala documentation of the new Scala API
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator{
case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean)
case class absClassFunction(name : String, desc : String,
listOfArgs: List[absClassArg], returnType : String)
private[mxnet] object APIDocGenerator extends GeneratorBase {


def main(args: Array[String]) : Unit = {
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
hashCollector += absClassGen(FILE_PATH, true)
hashCollector += absClassGen(FILE_PATH, false)
hashCollector += typeSafeClassGen(FILE_PATH, true)
hashCollector += typeSafeClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
// Generate Java API documentation
hashCollector += javaClassGen(FILE_PATH + "javaapi/")
hashCollector += javaClassGen(FILE_PATH)
val finalHash = hashCollector.mkString("\n")
}

def MD5Generator(input : String) : String = {
def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

def fileGen(filePath : String, packageName : String, packageDef : String,
absFuncs : List[String]) : String = {
val apacheLicense =
"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
|* contributor license agreements. See the NOTICE file distributed with
|* this work for additional information regarding copyright ownership.
|* The ASF licenses this file to You under the Apache License, Version 2.0
|* (the "License"); you may not use this file except in compliance with
|* the License. You may obtain a copy of the License at
|*
|* http://www.apache.org/licenses/LICENSE-2.0
|*
|* Unless required by applicable law or agreed to in writing, software
|* distributed under the License is distributed on an "AS IS" BASIS,
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|* See the License for the specific language governing permissions and
|* limitations under the License.
|*/
|""".stripMargin
val scalaStyle = "// scalastyle:off"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
.map { func =>
val scalaDoc = generateAPIDocFromBackend(func)
val decl = generateAPISignature(func, isSymbol)
s"$scalaDoc\n$decl"
}

val finalStr =
s"""$apacheLicense
|$scalaStyle
|$packageDef
|$imports
|$absClassDef {
|${absFuncs.mkString("\n")}
|}""".stripMargin
val pw = new PrintWriter(new File(filePath + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
writeFile(
FILE_PATH,
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
"package org.apache.mxnet",
generated)
}

def absClassGen(filePath : String, isSymbol : Boolean) : String = {
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// Defines Operators that should not generated
val notGenerated = Set("Custom")
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateAPISignature(absClassFunction, isSymbol)
s"$scalaDoc\n$defBody"
})
val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
val packageDef = "package org.apache.mxnet"
fileGen(filePath, packageName, packageDef, absFuncs)
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
.map { func =>
val scalaDoc = generateAPIDocFromBackend(func, false)
if (isSymbol) {
s"""$scalaDoc
|def ${func.name}(name : String = null, attr : Map[String, String] = null)
| (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null):
| org.apache.mxnet.Symbol
""".stripMargin
} else {
s"""$scalaDoc
|def ${func.name}(kwargs: Map[String, Any] = null)
| (args: Any*): org.apache.mxnet.NDArrayFuncReturn
|
|$scalaDoc
|def ${func.name}(args: Any*): org.apache.mxnet.NDArrayFuncReturn
""".stripMargin
}
}

writeFile(
FILE_PATH,
if (isSymbol) "SymbolBase" else "NDArrayBase",
"package org.apache.mxnet",
absFuncs)
}

def javaClassGen(filePath : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = getSymbolNDArrayMethods(false, true)
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val absClassFunctions = functionsToGenerate(false, false, true)
val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
.groupBy(_.name.toLowerCase).map(ele => {
// Pattern matching for not generating depreciated method
if (ele._2.length == 1) ele._2.head
else {
if (ele._2.head.name.head.isLower) ele._2.head
else ele._2.last
}
}).map(absClassFunction => {
generateJavaAPISignature(absClassFunction)
})
}).toSeq
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
fileGen(filePath, packageName, packageDef, absFuncs)
writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
}

def nonTypeSafeClassGen(filePath : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
if (isSymbol) {
val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
s"$scalaDoc\n$defBody"
} else {
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
val packageDef = "package org.apache.mxnet"
fileGen(filePath, packageName, packageDef, absFuncs)
}
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
val desc = func.desc.split("\n")
.mkString(" * <pre>\n", "\n * ", " * </pre>\n")

/**
* Some of the C++ type name is not valid in Scala
* such as var and type. This method is to convert
* them into other names to get it passed
* @param in the input String
* @return converted name string
*/
def safetyNameCheck(in : String) : String = {
in match {
case "var" => "vari"
case "type" => "typeOf"
case _ => in
val params = func.listOfArgs.map { absClassArg =>
s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
}
}

// Generate ScalaDoc type
def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
val desc = ArrayBuffer[String]()
desc += " * <pre>"
func.desc.split("\n").foreach({ currStr =>
desc += s" * $currStr"
})
desc += " * </pre>"
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = safetyNameCheck(absClassArg.argName)
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"

if (withParam) {
s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
s""" /**
|$desc
|${params.mkString("\n")}
|$returnType
| */""".stripMargin
} else {
s" /**\n${desc.mkString("\n")}\n$returnType\n */"
s""" /**
|$desc
|$returnType
| */""".stripMargin
}
}

def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
var argDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = safetyNameCheck(absClassArg.argName)
if (absClassArg.isOptional) {
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
}
else {
argDef += s"$currArgName : ${absClassArg.argType}"
}
})
var returnType = func.returnType
def generateAPISignature(func: Func, isSymbol: Boolean): String = {
val argDef = ListBuffer[String]()

argDef ++= typedFunctionCommonArgDef(func)

if (isSymbol) {
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
} else {
argDef += "out : Option[NDArray] = None"
returnType = "org.apache.mxnet.NDArrayFuncReturn"
}
val experimentalTag = "@Experimental"
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"

val returnType = func.returnType

s"""@Experimental
|def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
}

def generateJavaAPISignature(func : absClassFunction) : String = {
def generateJavaAPISignature(func : Func) : String = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
var requiredParam = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = safetyNameCheck(absClassArg.argName)
val currArgName = absClassArg.safeArgName
// scalastyle:off
if (absClassArg.isOptional && useParamObject) {
classDef +=
Expand Down Expand Up @@ -240,15 +191,15 @@ private[mxnet] object APIDocGenerator{
| def getOut() = this.out
| """.stripMargin
s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| /**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| /**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin
} else {
argDef += "out : NDArray"
s"""$scalaDoc
Expand All @@ -258,48 +209,40 @@ private[mxnet] object APIDocGenerator{
}
}

def writeFile(FILE_PATH: String, className: String, packageDef: String,
absFuncs: Seq[String]): String = {

// List and add all the atomic symbol functions to current module.
private def getSymbolNDArrayMethods(isSymbol : Boolean,
isJava : Boolean = false): List[absClassFunction] = {
val opNames = ListBuffer.empty[String]
val returnType = if (isSymbol) "Symbol" else "NDArray"
val returnHeader = if (isJava) "org.apache.mxnet.javaapi." else "org.apache.mxnet."
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
// TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, returnHeader + returnType)
}).filterNot(_.name.startsWith("_")).groupBy(_.name.toLowerCase).map(ele => {
// Pattern matching for not generating depreciated method
if (ele._2.length == 1) ele._2.head
else {
if (ele._2.head.name.head.isLower) ele._2.head
else ele._2.last
}
}).toList
}

// Create an atomic symbol function by handle and function name.
private def makeAtomicSymbolFunction(handle: SymbolHandle,
aliasName: String, returnType : String)
: absClassFunction = {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
val numArgs = new RefInt
val argNames = ListBuffer.empty[String]
val argTypes = ListBuffer.empty[String]
val argDescs = ListBuffer.empty[String]
val finalStr =
s"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
|* contributor license agreements. See the NOTICE file distributed with
|* this work for additional information regarding copyright ownership.
|* The ASF licenses this file to You under the Apache License, Version 2.0
|* (the "License"); you may not use this file except in compliance with
|* the License. You may obtain a copy of the License at
|*
|* http://www.apache.org/licenses/LICENSE-2.0
|*
|* Unless required by applicable law or agreed to in writing, software
|* distributed under the License is distributed on an "AS IS" BASIS,
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|* See the License for the specific language governing permissions and
|* limitations under the License.
|*/
|
|$packageDef
|
|import org.apache.mxnet.annotation.Experimental
|
|// scalastyle:off
|abstract class $className {
|${absFuncs.mkString("\n")}
|}""".stripMargin

_LIB.mxSymbolGetAtomicSymbolInfo(
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
new absClassFunction(aliasName, desc.value, argList.toList, returnType)
val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
}
}

}
Loading

0 comments on commit 251d2cd

Please sign in to comment.