Skip to content

Commit

Permalink
[MXNET-918] Introduce Random module / Refact code generation (apache#…
Browse files Browse the repository at this point in the history
…13038)

* refactor code gen

* remove xxxAPIMacroBase (overkill)

* CI errors / scala-style

* PR review comments
  • Loading branch information
mdespriee authored and azai91 committed Dec 1, 2018
1 parent 8e3bfea commit c86b103
Show file tree
Hide file tree
Showing 4 changed files with 411 additions and 493 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,178 +17,154 @@

package org.apache.mxnet

import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.CToScalaUtils
import java.io._
import java.security.MessageDigest

import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.collection.mutable.ListBuffer

/**
* This object will generate the Scala documentation of the new Scala API
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator{
case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean)
case class absClassFunction(name : String, desc : String,
listOfArgs: List[absClassArg], returnType : String)
private[mxnet] object APIDocGenerator extends GeneratorBase {


def main(args: Array[String]) : Unit = {
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
hashCollector += absClassGen(FILE_PATH, true)
hashCollector += absClassGen(FILE_PATH, false)
hashCollector += typeSafeClassGen(FILE_PATH, true)
hashCollector += typeSafeClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
val finalHash = hashCollector.mkString("\n")
}

def MD5Generator(input : String) : String = {
def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// Defines Operators that should not generated
val notGenerated = Set("Custom")
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateAPISignature(absClassFunction, isSymbol)
s"$scalaDoc\n$defBody"
})
val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
.map { func =>
val scalaDoc = generateAPIDocFromBackend(func)
val decl = generateAPISignature(func, isSymbol)
s"$scalaDoc\n$decl"
}

writeFile(
FILE_PATH,
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
"package org.apache.mxnet",
generated)
}

def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
if (isSymbol) {
val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
s"$scalaDoc\n$defBody"
} else {
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
.map { func =>
val scalaDoc = generateAPIDocFromBackend(func, false)
if (isSymbol) {
s"""$scalaDoc
|def ${func.name}(name : String = null, attr : Map[String, String] = null)
| (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null):
| org.apache.mxnet.Symbol
""".stripMargin
} else {
s"""$scalaDoc
|def ${func.name}(kwargs: Map[String, Any] = null)
| (args: Any*): org.apache.mxnet.NDArrayFuncReturn
|
|$scalaDoc
|def ${func.name}(args: Any*): org.apache.mxnet.NDArrayFuncReturn
""".stripMargin
}
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)

writeFile(
FILE_PATH,
if (isSymbol) "SymbolBase" else "NDArrayBase",
"package org.apache.mxnet",
absFuncs)
}

// Generate ScalaDoc type
def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
val desc = ArrayBuffer[String]()
desc += " * <pre>"
func.desc.split("\n").foreach({ currStr =>
desc += s" * $currStr"
})
desc += " * </pre>"
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
val desc = func.desc.split("\n")
.mkString(" * <pre>\n", "\n * ", " * </pre>\n")

val params = func.listOfArgs.map { absClassArg =>
s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
}

val returnType = s" * @return ${func.returnType}"

if (withParam) {
s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
s""" /**
|$desc
|${params.mkString("\n")}
|$returnType
| */""".stripMargin
} else {
s" /**\n${desc.mkString("\n")}\n$returnType\n */"
s""" /**
|$desc
|$returnType
| */""".stripMargin
}
}

def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
var argDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
if (absClassArg.isOptional) {
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
}
else {
argDef += s"$currArgName : ${absClassArg.argType}"
}
})
var returnType = func.returnType
def generateAPISignature(func: Func, isSymbol: Boolean): String = {
val argDef = ListBuffer[String]()

argDef ++= typedFunctionCommonArgDef(func)

if (isSymbol) {
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
} else {
argDef += "out : Option[NDArray] = None"
returnType = "org.apache.mxnet.NDArrayFuncReturn"
}
val experimentalTag = "@Experimental"
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
}

val returnType = func.returnType

// List and add all the atomic symbol functions to current module.
private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
val opNames = ListBuffer.empty[String]
val returnType = if (isSymbol) "Symbol" else "NDArray"
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
// TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
}).toList.filterNot(_.name.startsWith("_"))
s"""@Experimental
|def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
}

// Create an atomic symbol function by handle and function name.
private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
: absClassFunction = {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
val numArgs = new RefInt
val argNames = ListBuffer.empty[String]
val argTypes = ListBuffer.empty[String]
val argDescs = ListBuffer.empty[String]

_LIB.mxSymbolGetAtomicSymbolInfo(
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
new absClassFunction(aliasName, desc.value, argList.toList, returnType)
def writeFile(FILE_PATH: String, className: String, packageDef: String,
absFuncs: Seq[String]): String = {

val finalStr =
s"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
|* contributor license agreements. See the NOTICE file distributed with
|* this work for additional information regarding copyright ownership.
|* The ASF licenses this file to You under the Apache License, Version 2.0
|* (the "License"); you may not use this file except in compliance with
|* the License. You may obtain a copy of the License at
|*
|* http://www.apache.org/licenses/LICENSE-2.0
|*
|* Unless required by applicable law or agreed to in writing, software
|* distributed under the License is distributed on an "AS IS" BASIS,
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|* See the License for the specific language governing permissions and
|* limitations under the License.
|*/
|
|$packageDef
|
|import org.apache.mxnet.annotation.Experimental
|
|// scalastyle:off
|abstract class $className {
|${absFuncs.mkString("\n")}
|}""".stripMargin

val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
}

}
Loading

0 comments on commit c86b103

Please sign in to comment.