diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java index 86c7eb29d2ef..1b7042d49795 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java @@ -86,7 +86,7 @@ public void testGenerated(){ NDArray$ NDArray = NDArray$.MODULE$; float[] arr = new float[]{1.0f, 2.0f, 3.0f}; NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0)); - float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0]; + float result = NDArray.norm(new normParam(nd))[0].toArray()[0]; float cal = 0.0f; for (float ele : arr) { cal += ele * ele; @@ -94,7 +94,7 @@ public void testGenerated(){ cal = (float) Math.sqrt(cal); assertTrue(Math.abs(result - cal) < 1e-5); NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0)); - NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult)); + NDArray.dot(new dotParam(nd, nd).setOut(dotResult)); assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f})); } } diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java index b40a4e94afbd..dd17b1d4a0a5 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/bert/BertQA.java @@ -68,15 +68,15 @@ private static int argmax(float[] prob) { */ static List postProcessing(NDArray result, List tokens) { NDArray[] output = NDArray.split( - NDArray.new splitParam(result, 2).setAxis(2)); + new splitParam(result, 2).setAxis(2)); // Get the formatted logits result NDArray startLogits = output[0].reshape(new int[]{0, -3}); NDArray endLogits = output[1].reshape(new int[]{0, -3}); // Get Probability distribution float[] startProb = NDArray.softmax( - NDArray.new softmaxParam(startLogits))[0].toArray(); + new softmaxParam(startLogits))[0].toArray(); float[] endProb = NDArray.softmax( - NDArray.new softmaxParam(endLogits))[0].toArray(); + new softmaxParam(endLogits))[0].toArray(); int startIdx = argmax(startProb); int endIdx = argmax(endProb); return tokens.subList(startIdx, endIdx + 1); diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index a5102d6624ef..e939b2ebf9e7 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -152,7 +152,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { def javaClassGen(FILE_PATH : String) : String = { val notGenerated = Set("Custom") val absClassFunctions = functionsToGenerate(false, false, true) - val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name)) + val (absFuncs, paramClassUncleaned) = + absClassFunctions.filterNot(ele => notGenerated.contains(ele.name)) .groupBy(_.name.toLowerCase).map(ele => { /* Pattern matching for not generating deprecated method * Group all method name in lowercase @@ -166,7 +167,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { } }).map(absClassFunction => { generateJavaAPISignature(absClassFunction) - }).toSeq + }).toSeq.unzip + val paramClass = paramClassUncleaned.filterNot(_.isEmpty) val packageName = "NDArrayBase" val packageDef = "package org.apache.mxnet.javaapi" writeFile( @@ -174,7 +176,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { packageDef, packageName, "import org.apache.mxnet.annotation.Experimental", - absFuncs) + absFuncs, Some(paramClass)) } /** @@ -248,7 +250,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { * @param func The function case class * @return A formatted string for the function */ - def generateJavaAPISignature(func : Func) : String = { + def generateJavaAPISignature(func : Func) : (String, String) = { val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2 var argDef = ListBuffer[String]() var classDef = ListBuffer[String]() @@ -287,22 +289,23 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { | } | def getOut() = this.out | """.stripMargin - s"""$scalaDocNoParam + (s"""$scalaDocNoParam | $experimentalTag | def ${func.name}(po: ${func.name}Param) : $returnType - | /** + | """.stripMargin, + s"""/** | * This Param Object is specifically used for ${func.name} | ${requiredParam.mkString("\n")} | */ | class ${func.name}Param(${argDef.mkString(",")}) { | ${classDef.mkString("\n ")} - | }""".stripMargin + | }""".stripMargin) } else { argDef += "out : NDArray" - s"""$scalaDoc + (s"""$scalaDoc |$experimentalTag | def ${func.name}(${argDef.mkString(", ")}) : $returnType - | """.stripMargin + | """.stripMargin, "") } } @@ -316,7 +319,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { * @return A MD5 string */ def writeFile(FILE_PATH: String, packageDef: String, className: String, - imports: String, absFuncs: Seq[String]): String = { + imports: String, absFuncs: Seq[String], + paramClass: Option[Seq[String]] = None): String = { val finalStr = s"""/* @@ -343,7 +347,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers { |// scalastyle:off |abstract class $className { |${absFuncs.mkString("\n")} - |}""".stripMargin + |} + |${paramClass.getOrElse(Seq()).mkString("\n")} + |""".stripMargin val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala")) diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java index 32e2d84dcdbf..4361c06edf32 100644 --- a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java +++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java @@ -37,7 +37,7 @@ public static void main(String[] args) { // random NDArray random = NDArray.random_uniform( - NDArray.new random_uniformParam() + new random_uniformParam() .setLow(0.0f) .setHigh(2.0f) .setShape(new Shape(new int[]{10, 10})) diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java index 56a414307f46..646adf5550b1 100644 --- a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java +++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java @@ -38,7 +38,7 @@ public static void main(String[] args) { System.out.println(eleAdd); // norm (L2 Norm) - NDArray normed = NDArray.norm(NDArray.new normParam(nd))[0]; + NDArray normed = NDArray.norm(new normParam(nd))[0]; System.out.println(normed); } }