Skip to content

Commit

Permalink
clean up submodule (apache#14645)
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Apr 29, 2019
1 parent 19d78e5 commit b7bfcfc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ public void testGenerated(){
NDArray$ NDArray = NDArray$.MODULE$;
float[] arr = new float[]{1.0f, 2.0f, 3.0f};
NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
float result = NDArray.norm(new normParam(nd))[0].toArray()[0];
float cal = 0.0f;
for (float ele : arr) {
cal += ele * ele;
}
cal = (float) Math.sqrt(cal);
assertTrue(Math.abs(result - cal) < 1e-5);
NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
NDArray.dot(new dotParam(nd, nd).setOut(dotResult));
assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
def javaClassGen(filePath : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = functionsToGenerate(false, false, true)
val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
val (absFuncs, paramClassUncleaned) =
absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
.groupBy(_.name.toLowerCase).map(ele => {
/* Pattern matching for not generating deprecated method
* Group all method name in lowercase
Expand All @@ -107,10 +108,16 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}
}).map(absClassFunction => {
generateJavaAPISignature(absClassFunction)
}).toSeq
}).toSeq.unzip
val paramClass = paramClassUncleaned.filterNot(_.isEmpty)
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
writeFile(
FILE_PATH + "javaapi/",
packageDef,
packageName,
"import org.apache.mxnet.annotation.Experimental",
absFuncs, Some(paramClass))
}

def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
Expand Down Expand Up @@ -165,7 +172,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
|def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
}

def generateJavaAPISignature(func : Func) : String = {
/**
* Generate Java function interface
* @param func The function case class
* @return A formatted string for the function
*/
def generateJavaAPISignature(func : Func) : (String, String) = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
Expand Down Expand Up @@ -204,27 +216,38 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
| }
| def getOut() = this.out
| """.stripMargin
s"""$scalaDocNoParam
(s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| /**
| """.stripMargin,
s"""/**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin
| }""".stripMargin)
} else {
argDef += "out : NDArray"
s"""$scalaDoc
(s"""$scalaDoc
|$experimentalTag
| def ${func.name}(${argDef.mkString(", ")}) : $returnType
| """.stripMargin
| """.stripMargin, "")
}
}

def writeFile(FILE_PATH: String, className: String, packageDef: String,
absFuncs: Seq[String]): String = {
/**
* Write the formatted string to file
* @param FILE_PATH Location of the file writes to
* @param packageDef Package definition
* @param className Class name
* @param imports Packages need to import
* @param absFuncs All formatted functions
* @return A MD5 string
*/
def writeFile(FILE_PATH: String, packageDef: String, className: String,
imports: String, absFuncs: Seq[String],
paramClass: Option[Seq[String]] = None): String = {

val finalStr =
s"""/*
Expand All @@ -251,7 +274,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
|// scalastyle:off
|abstract class $className {
|${absFuncs.mkString("\n")}
|}""".stripMargin
|}
|${paramClass.getOrElse(Seq()).mkString("\n")}
|""".stripMargin


val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
Expand Down

0 comments on commit b7bfcfc

Please sign in to comment.