Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
refactoring : extract GeneratorBase
Browse files Browse the repository at this point in the history
  • Loading branch information
mdespriee committed Oct 26, 2018
1 parent 053f9fa commit 6cbd132
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 357 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.mxnet

import org.scalatest.{BeforeAndAfterAll, FunSuite}


class SymbolSuite extends FunSuite with BeforeAndAfterAll {

test("symbol compose") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.mxnet

import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.CToScalaUtils
import java.io._
import java.security.MessageDigest

Expand All @@ -29,13 +27,11 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator{
case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean)
case class absClassFunction(name : String, desc : String,
listOfArgs: List[absClassArg], returnType : String)
private[mxnet] object APIDocGenerator extends GeneratorBase {
type absClassArg = Arg
type absClassFunction = Func


def main(args: Array[String]) : Unit = {
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
hashCollector += absClassGen(FILE_PATH, true)
Expand All @@ -47,68 +43,70 @@ private[mxnet] object APIDocGenerator{
val finalHash = hashCollector.mkString("\n")
}

def MD5Generator(input : String) : String = {
def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

def absRndClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
typeSafeClassGen(
getSymbolNDArrayMethods(isSymbol)
.filter(f => f.name.startsWith("_random") || f.name.startsWith("_sample"))
.map(f => f.copy(name = f.name.stripPrefix("_"))),
def absRndClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val funcs = getSymbolNDArrayMethods(isSymbol)
.filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_"))
.map(f => f.copy(name = f.name.stripPrefix("_")))
val body = funcs.map(func => {
val scalaDoc = generateAPIDocFromBackend(func)
val decl = generateRandomAPISignature(func, isSymbol)
s"$scalaDoc\n$decl"
})
writeFile(
FILE_PATH,
if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
isSymbol
)
body)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
def absClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val notGenerated = Set("Custom")
typeSafeClassGen(
getSymbolNDArrayMethods(isSymbol)
.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name)),
val funcs = getSymbolNDArrayMethods(isSymbol)
.filterNot(_.name.startsWith("_"))
.filterNot(ele => notGenerated.contains(ele.name))
val body = funcs.map(func => {
val scalaDoc = generateAPIDocFromBackend(func)
val decl = generateAPISignature(func, isSymbol)
s"$scalaDoc\n$decl"
})
writeFile(
FILE_PATH,
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
isSymbol
)
body)
}

def typeSafeClassGen(absClassFunctions: Seq[absClassFunction], FILE_PATH: String,
packageName: String, isSymbol: Boolean): String = {
val absFuncs = absClassFunctions
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateAPISignature(absClassFunction, isSymbol)
s"$scalaDoc\n$defBody"
})
writeFile(FILE_PATH, packageName, absFuncs)
}

def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
val absFuncs = absClassFunctions
.filterNot(_.name.startsWith("_"))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
if (isSymbol) {
val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
s"$scalaDoc\n$defBody"
} else {
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
}
})
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
if (isSymbol) {
val defBody =
s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)" +
s"(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): " +
s"org.apache.mxnet.Symbol"
s"$scalaDoc\n$defBody"
} else {
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)" +
s"(args: Any*): " +
s"org.apache.mxnet.NDArrayFuncReturn"
val defBody = s"def ${absClassFunction.name}(args: Any*): " +
s"org.apache.mxnet.NDArrayFuncReturn"
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
}
})
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
writeFile(FILE_PATH, packageName, absFuncs)
}

def writeFile(FILE_PATH: String, packageName: String, absFuncs: Seq[String]): String = {
def writeFile(FILE_PATH: String, packageName: String, body: Seq[String]): String = {
val apacheLicence =
"""/*
|* Licensed to the Apache Software Foundation (ASF) under one or more
Expand Down Expand Up @@ -137,7 +135,7 @@ private[mxnet] object APIDocGenerator{
|$packageDef
|$imports
|$absClassDef {
|${absFuncs.mkString("\n")}
|${body.mkString("\n")}
|}""".stripMargin
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
Expand All @@ -146,20 +144,15 @@ private[mxnet] object APIDocGenerator{
}

// Generate ScalaDoc type
def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
def generateAPIDocFromBackend(func: absClassFunction, withParam: Boolean = true): String = {
val desc = ArrayBuffer[String]()
desc += " * <pre>"
func.desc.split("\n").foreach({ currStr =>
func.desc.split("\n").foreach({ currStr =>
desc += s" * $currStr"
})
desc += " * </pre>"
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
s" * @param $currArgName\t\t${absClassArg.argDesc}"
s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
if (withParam) {
Expand All @@ -169,64 +162,31 @@ private[mxnet] object APIDocGenerator{
}
}

def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
var argDef = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = absClassArg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => absClassArg.argName
}
if (absClassArg.isOptional) {
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
}
else {
argDef += s"$currArgName : ${absClassArg.argType}"
}
})
var returnType = func.returnType
def generateRandomAPISignature(func: absClassFunction, isSymbol: Boolean): String = {
generateAPISignature(func, isSymbol)
}

def generateAPISignature(func: absClassFunction, isSymbol: Boolean): String = {
val argDef = ListBuffer[String]()

argDef ++= buildArgDefs(func)

if (isSymbol) {
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
} else {
argDef += "out : Option[NDArray] = None"
returnType = "org.apache.mxnet.NDArrayFuncReturn"
}

val returnType = func.returnType

val experimentalTag = "@Experimental"
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
}

// List and add all the atomic symbol functions to current module.
private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
val opNames = ListBuffer.empty[String]
val returnType = if (isSymbol) "Symbol" else "NDArray"
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
// TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
}).toList
private def getSymbolNDArrayMethods(isSymbol: Boolean): List[absClassFunction] = {
buildFunctionList(isSymbol)
}

// Create an atomic symbol function by handle and function name.
private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
: absClassFunction = {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
val numArgs = new RefInt
val argNames = ListBuffer.empty[String]
val argTypes = ListBuffer.empty[String]
val argDescs = ListBuffer.empty[String]

_LIB.mxSymbolGetAtomicSymbolInfo(
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
absClassFunction(aliasName, desc.value, argList.toList, returnType)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package org.apache.mxnet

import org.apache.mxnet.init.Base.{RefInt, RefLong, RefString, _LIB}
import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}

import scala.collection.mutable.ListBuffer
import scala.reflect.macros.blackbox

abstract class GeneratorBase {
type Handle = Long

case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) {
def safeArgName: String = argName match {
case "var" => "vari"
case "type" => "typeOf"
case _ => argName
}
}

case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String)

protected def buildFunctionList(isSymbol: Boolean): List[Func] = {
val opNames = ListBuffer.empty[String]
_LIB.mxListAllOpNames(opNames)
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicFunction(opHandle.value, opName, isSymbol)
}).toList
}

protected def makeAtomicFunction(handle: Handle, aliasName: String, isSymbol: Boolean): Func = {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
val numArgs = new RefInt
val argNames = ListBuffer.empty[String]
val argTypes = ListBuffer.empty[String]
val argDescs = ListBuffer.empty[String]

_LIB.mxSymbolGetAtomicSymbolInfo(
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs)
val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) {
s"This function support variable length of positional input (${keyVarNumArgs.value})."
} else {
""
}
val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"
val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
// scalastyle:off println
if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
&& System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
println("Function definition:\n" + docStr)
}
// scalastyle:on println
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
val family = if(isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArray"
val typeAndOption =
CToScalaUtils.argumentCleaner(argName, argType, family)
Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
val returnType = if(isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArrayFuncReturn"
Func(aliasName, desc.value, argList.toList, returnType)
}

/**
* Generate class structure for all function APIs
*
* @param c
* @param funcDef DefDef type of function definitions
* @param annottees
* @return
*/
protected def structGeneration(c: blackbox.Context)
(funcDef: List[c.universe.DefDef], annottees: c.Expr[Any]*)
: c.Expr[Any] = {
import c.universe._
val inputs = annottees.map(_.tree).toList
// pattern match on the inputs
val modDefs = inputs map {
case ClassDef(mods, name, something, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
ClassDef(mods, name, something, q)
case ModuleDef(mods, name, template) =>
val q = template match {
case Template(superMaybe, emptyValDef, defs) =>
Template(superMaybe, emptyValDef, defs ++ funcDef)
case ex =>
throw new IllegalArgumentException(s"Invalid template: $ex")
}
ModuleDef(mods, name, q)
case ex =>
throw new IllegalArgumentException(s"Invalid macro input: $ex")
}
// wrap the result up in an Expr, and return it
val result = c.Expr(Block(modDefs, Literal(Constant())))
result
}

protected def buildArgDefs(func: Func): List[String] = {
func.listOfArgs.map(arg =>
if (arg.isOptional)
s"${arg.safeArgName} : Option[${arg.argType}] = None"
else
s"${arg.safeArgName} : ${arg.argType}"
)
}


}
Loading

0 comments on commit 6cbd132

Please sign in to comment.