From b8ac3a447bdd6cf077202313a79a22fd3b357783 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 7 Nov 2018 00:26:23 -0800 Subject: [PATCH 1/8] applying changes for Builder functions --- .../org/apache/mxnet/javaapi/NDArrayTest.java | 6 +- .../org/apache/mxnet/APIDocGenerator.scala | 30 ++++++--- .../mxnet/javaapi/JavaNDArrayMacro.scala | 66 ++++++++++++------- 3 files changed, 68 insertions(+), 34 deletions(-) diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java index a9bad83f62d6..9ba69f285e75 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java @@ -19,9 +19,9 @@ import org.junit.Test; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.mxnet.javaapi.NDArray.*; import static org.junit.Assert.assertTrue; @@ -71,7 +71,7 @@ public void testGenerated(){ NDArray$ NDArray = NDArray$.MODULE$; float[] arr = new float[]{1.0f, 2.0f, 3.0f}; NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0)); - float result = NDArray.norm(nd).invoke().get().toArray()[0]; + float result = NDArray.norm(new normParam(nd)).get().toArray()[0]; float cal = 0.0f; for (float ele : arr) { cal += ele * ele; @@ -79,7 +79,7 @@ public void testGenerated(){ cal = (float) Math.sqrt(cal); assertTrue(Math.abs(result - cal) < 1e-5); NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0)); - NDArray.dot(nd, nd).setout(dotResult).invoke().get(); + NDArray.dot(new dotParam(nd, nd).setOut(dotResult)).get(); assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f})); } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 44d47a2099d5..d454bf65bc0f 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -203,26 +203,40 @@ private[mxnet] object APIDocGenerator{ } def generateJavaAPISignature(func : absClassFunction) : String = { + val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2 var argDef = ListBuffer[String]() var classDef = ListBuffer[String]() func.listOfArgs.foreach(absClassArg => { val currArgName = safetyNameCheck(absClassArg.argName) // scalastyle:off - if (absClassArg.isOptional) { - classDef += s"def set${absClassArg.argName}(${absClassArg.argName} : ${absClassArg.argType}) : ${func.name}BuilderBase" + if (absClassArg.isOptional && useParamObject) { + classDef += s"def set${absClassArg.argName.capitalize}(${absClassArg.argName}: ${absClassArg.argType}): ${func.name}ParamBase" } else { argDef += s"$currArgName : ${absClassArg.argType}" } // scalastyle:on }) - classDef += s"def setout(out : NDArray) : ${func.name}BuilderBase" - classDef += s"def invoke() : org.apache.mxnet.javaapi.NDArrayFuncReturn" val experimentalTag = "@Experimental" - // scalastyle:off - var finalStr = s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : ${func.name}BuilderBase\n" - // scalastyle:on - finalStr += s"abstract class ${func.name}BuilderBase {\n ${classDef.mkString("\n ")}\n}" + val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn" + var finalStr = "" + if(useParamObject) { + classDef += s"def setOut(out : NDArray) : ${func.name}ParamBase" + classDef += s"def invoke() : $returnType" + finalStr = s"""$experimentalTag + | def ${func.name}(po: ${func.name}ParamBase) : $returnType + | """.stripMargin + finalStr += + s"""abstract class ${func.name}ParamBase(${argDef.mkString(",")}) { + | ${classDef.mkString("\n ")} + |}""".stripMargin + } else { + argDef += "out : NDArray" + finalStr = + s"""$experimentalTag + | def ${func.name}(${argDef.mkString(", ")}) : $returnType + | """.stripMargin + } finalStr } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala index d5be97b501c5..035305c256fc 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -68,6 +68,8 @@ private[mxnet] object JavaNDArrayMacro { newNDArrayFunctions.foreach { ndarrayfunction => + val useParamObject = ndarrayfunction.listOfArgs.count(arg => arg.isOptional) >= 2 + val header = if (useParamObject) "this." else "" // Construct argument field with all required args var argDef = ListBuffer[String]() // Construct Optional Arg @@ -75,9 +77,8 @@ private[mxnet] object JavaNDArrayMacro { // Construct function Implementation field (e.g norm) var impl = ListBuffer[String]() impl += "val map = scala.collection.mutable.Map[String, Any]()" - // scalastyle:off - impl += "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]" - // scalastyle:on + impl += + "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]" // Construct Class Implementation (e.g normBuilder) var classImpl = ListBuffer[String]() ndarrayfunction.listOfArgs.foreach({ ndarrayArg => @@ -88,9 +89,9 @@ private[mxnet] object JavaNDArrayMacro { case "type" => "typeOf" case _ => ndarrayArg.argName } - if (ndarrayArg.isOptional) { + if (ndarrayArg.isOptional && useParamObject) { OptionArgDef += s"private var $currArgName : ${ndarrayArg.argType} = null" - val tempDef = s"def set$currArgName($currArgName : ${ndarrayArg.argType})" + val tempDef = s"def set${currArgName.capitalize}($currArgName : ${ndarrayArg.argType})" val tempImpl = s"this.$currArgName = $currArgName\nthis" classImpl += s"$tempDef = {$tempImpl}" } else { @@ -100,35 +101,54 @@ private[mxnet] object JavaNDArrayMacro { val returnType = "org.apache.mxnet.javaapi.NDArray" val base = if (ndarrayArg.argType.equals(returnType)) { - s"args += this.$currArgName" + s"args += $header$currArgName" } else if (ndarrayArg.argType.equals(s"Array[$returnType]")){ - s"this.$currArgName.foreach(args+=_)" + s"$header$currArgName.foreach(args+=_)" } else { - "map(\"" + ndarrayArg.argName + "\") = this." + currArgName + "map(\"" + ndarrayArg.argName + "\") = " + header + currArgName } impl.append( - if (ndarrayArg.isOptional) s"if (this.$currArgName != null) $base" + if (ndarrayArg.isOptional) s"if ($header$currArgName != null) $base" else base ) }) // add default out parameter - classImpl += - "def setout(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}" - impl += "if (this.out != null) map(\"out\") = this.out" - OptionArgDef += "private var out : org.apache.mxnet.NDArray = null" + if (useParamObject) { + classImpl += + "def setOut(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}" + } else { + argDef += s"out: org.apache.mxnet.javaapi.NDArray" + } + impl += "if (" + header + "out != null) map(\"out\") = " + header + "out" + OptionArgDef += "var out : org.apache.mxnet.NDArray = null" val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn" // scalastyle:off // Combine and build the function string - impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)" - val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")}) extends ${ndarrayfunction.name}BuilderBase" - val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}" - val classFinal = s"$classDef {$classBody}" - val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})" - val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})" - val functionFinal = s"$functionDef : ${ndarrayfunction.name}BuilderBase = $functionBody" - // scalastyle:on - functionDefs += c.parse(functionFinal).asInstanceOf[DefDef] - classDefs += c.parse(classFinal).asInstanceOf[ClassDef] + impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + + ndarrayfunction.name + "\", args.toSeq, map.toMap)" + if (useParamObject) { + val classDef = + s"""class ${ndarrayfunction.name}Param(${argDef.mkString(",")}) + | extends ${ndarrayfunction.name}ParamBase(${argDef.mkString(",")}) { + | ${OptionArgDef.mkString("\n")} + | ${classImpl.mkString("\n")} + | def invoke() : $returnType = { + | ${impl.mkString("\n")} + | } + | }""".stripMargin + classDefs += c.parse(classDef).asInstanceOf[ClassDef] + val funcDef = + s"""def ${ndarrayfunction.name}(po: ${ndarrayfunction.name}ParamBase): $returnType = { + | po.invoke() + | }""".stripMargin + functionDefs += c.parse(funcDef).asInstanceOf[DefDef] + } else { + val funcDef = + s"""def ${ndarrayfunction.name}(${argDef.mkString(",")}): $returnType = { + | ${impl.mkString("\n")} + | }""".stripMargin + functionDefs += c.parse(funcDef).asInstanceOf[DefDef] + } } structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*) From 9d2ff2268eb47af431a4ed08f1c464ebd53e0afc Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 8 Nov 2018 12:56:57 -0800 Subject: [PATCH 2/8] simplify the code structure --- .../org/apache/mxnet/APIDocGenerator.scala | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index d454bf65bc0f..24ae18cac924 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -219,25 +219,20 @@ private[mxnet] object APIDocGenerator{ }) val experimentalTag = "@Experimental" val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn" - var finalStr = "" if(useParamObject) { classDef += s"def setOut(out : NDArray) : ${func.name}ParamBase" classDef += s"def invoke() : $returnType" - finalStr = s"""$experimentalTag + s"""$experimentalTag | def ${func.name}(po: ${func.name}ParamBase) : $returnType - | """.stripMargin - finalStr += - s"""abstract class ${func.name}ParamBase(${argDef.mkString(",")}) { - | ${classDef.mkString("\n ")} - |}""".stripMargin + | abstract class ${func.name}ParamBase(${argDef.mkString(",")}) { + | ${classDef.mkString("\n ")} + | }""".stripMargin } else { argDef += "out : NDArray" - finalStr = - s"""$experimentalTag - | def ${func.name}(${argDef.mkString(", ")}) : $returnType - | """.stripMargin + s"""$experimentalTag + | def ${func.name}(${argDef.mkString(", ")}) : $returnType + | """.stripMargin } - finalStr } From 274fcd9359ac8c4f1e6965dc1b95cba35db823b5 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 8 Nov 2018 16:52:05 -0800 Subject: [PATCH 3/8] update docgen --- .../org/apache/mxnet/APIDocGenerator.scala | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 24ae18cac924..312a767101c0 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -210,21 +210,29 @@ private[mxnet] object APIDocGenerator{ val currArgName = safetyNameCheck(absClassArg.argName) // scalastyle:off if (absClassArg.isOptional && useParamObject) { - classDef += s"def set${absClassArg.argName.capitalize}(${absClassArg.argName}: ${absClassArg.argType}): ${func.name}ParamBase" + classDef += + s"""private var $currArgName = null + |def set${currArgName.capitalize}($currArgName : ${absClassArg.argType}): ${func.name}Param = { + | this.$currArgName = $currArgName + | }""".stripMargin } else { argDef += s"$currArgName : ${absClassArg.argType}" } + classDef += s"def get${currArgName.capitalize}() = this.$currArgName" // scalastyle:on }) val experimentalTag = "@Experimental" val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn" if(useParamObject) { - classDef += s"def setOut(out : NDArray) : ${func.name}ParamBase" - classDef += s"def invoke() : $returnType" + classDef += + s"""private var out = null + |def setOut(out : NDArray) : ${func.name}Param = { + | this.out = out + | }""".stripMargin s"""$experimentalTag - | def ${func.name}(po: ${func.name}ParamBase) : $returnType - | abstract class ${func.name}ParamBase(${argDef.mkString(",")}) { + | def ${func.name}(po: ${func.name}Param) : $returnType + | class ${func.name}Param(${argDef.mkString(",")}) { | ${classDef.mkString("\n ")} | }""".stripMargin } else { From 1f7e1e639a4fba97d3d65e0a9a8fdd99ba98d088 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 8 Nov 2018 20:05:38 -0800 Subject: [PATCH 4/8] follow Naveen's suggestion --- .../org/apache/mxnet/javaapi/NDArray.scala | 14 ----- .../org/apache/mxnet/javaapi/NDArrayTest.java | 6 +- .../org/apache/mxnet/APIDocGenerator.scala | 12 ++-- .../mxnet/javaapi/JavaNDArrayMacro.scala | 59 ++++++------------- 4 files changed, 29 insertions(+), 62 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala index d4e67f73408e..cdcc292ada63 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala @@ -385,17 +385,3 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) { override def equals(obj: Any): Boolean = nd.equals(obj) override def hashCode(): Int = nd.hashCode } - -object NDArrayFuncReturn { - implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn) - : org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn - implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) - : NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn) -} - -private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) { - def head : NDArray = ndFuncReturn.head - def get : NDArray = ndFuncReturn.get - def apply(i : Int) : NDArray = ndFuncReturn.apply(i) - // TODO: Add JavaNDArray operational stuff -} diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java index 9ba69f285e75..2659b7848bc6 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java @@ -21,7 +21,7 @@ import java.util.Arrays; import java.util.List; -import org.apache.mxnet.javaapi.NDArray.*; +import org.apache.mxnet.javaapi.NDArrayBase.*; import static org.junit.Assert.assertTrue; @@ -71,7 +71,7 @@ public void testGenerated(){ NDArray$ NDArray = NDArray$.MODULE$; float[] arr = new float[]{1.0f, 2.0f, 3.0f}; NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0)); - float result = NDArray.norm(new normParam(nd)).get().toArray()[0]; + float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0]; float cal = 0.0f; for (float ele : arr) { cal += ele * ele; @@ -79,7 +79,7 @@ public void testGenerated(){ cal = (float) Math.sqrt(cal); assertTrue(Math.abs(result - cal) < 1e-5); NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0)); - NDArray.dot(new dotParam(nd, nd).setOut(dotResult)).get(); + NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult)); assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f})); } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 312a767101c0..8005e7e49206 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -211,9 +211,10 @@ private[mxnet] object APIDocGenerator{ // scalastyle:off if (absClassArg.isOptional && useParamObject) { classDef += - s"""private var $currArgName = null + s"""private var $currArgName: ${absClassArg.argType} = null |def set${currArgName.capitalize}($currArgName : ${absClassArg.argType}): ${func.name}Param = { | this.$currArgName = $currArgName + | this | }""".stripMargin } else { @@ -223,13 +224,16 @@ private[mxnet] object APIDocGenerator{ // scalastyle:on }) val experimentalTag = "@Experimental" - val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn" + val returnType = "Array[NDArray]" if(useParamObject) { classDef += - s"""private var out = null + s"""private var out : org.apache.mxnet.NDArray = null |def setOut(out : NDArray) : ${func.name}Param = { | this.out = out - | }""".stripMargin + | this + | } + | def getOut() = this.out + | """.stripMargin s"""$experimentalTag | def ${func.name}(po: ${func.name}Param) : $returnType | class ${func.name}Param(${argDef.mkString(",")}) { diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala index 035305c256fc..2d1827038afc 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -69,18 +69,13 @@ private[mxnet] object JavaNDArrayMacro { newNDArrayFunctions.foreach { ndarrayfunction => val useParamObject = ndarrayfunction.listOfArgs.count(arg => arg.isOptional) >= 2 - val header = if (useParamObject) "this." else "" // Construct argument field with all required args var argDef = ListBuffer[String]() - // Construct Optional Arg - var OptionArgDef = ListBuffer[String]() // Construct function Implementation field (e.g norm) var impl = ListBuffer[String]() impl += "val map = scala.collection.mutable.Map[String, Any]()" impl += "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]" - // Construct Class Implementation (e.g normBuilder) - var classImpl = ListBuffer[String]() ndarrayfunction.listOfArgs.foreach({ ndarrayArg => // var is a special word used to define variable in Scala, // need to changed to something else in order to make it work @@ -89,57 +84,40 @@ private[mxnet] object JavaNDArrayMacro { case "type" => "typeOf" case _ => ndarrayArg.argName } - if (ndarrayArg.isOptional && useParamObject) { - OptionArgDef += s"private var $currArgName : ${ndarrayArg.argType} = null" - val tempDef = s"def set${currArgName.capitalize}($currArgName : ${ndarrayArg.argType})" - val tempImpl = s"this.$currArgName = $currArgName\nthis" - classImpl += s"$tempDef = {$tempImpl}" - } else { - argDef += s"$currArgName : ${ndarrayArg.argType}" - } + if (useParamObject) currArgName = s"po.get${currArgName.capitalize}()" + argDef += s"$currArgName : ${ndarrayArg.argType}" // NDArray arg implementation val returnType = "org.apache.mxnet.javaapi.NDArray" val base = if (ndarrayArg.argType.equals(returnType)) { - s"args += $header$currArgName" + s"args += $currArgName" } else if (ndarrayArg.argType.equals(s"Array[$returnType]")){ - s"$header$currArgName.foreach(args+=_)" + s"$currArgName.foreach(args+=_)" } else { - "map(\"" + ndarrayArg.argName + "\") = " + header + currArgName + "map(\"" + ndarrayArg.argName + "\") = " + currArgName } impl.append( - if (ndarrayArg.isOptional) s"if ($header$currArgName != null) $base" + if (ndarrayArg.isOptional) s"if ($currArgName != null) $base" else base ) }) // add default out parameter + argDef += s"out: org.apache.mxnet.javaapi.NDArray" if (useParamObject) { - classImpl += - "def setOut(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}" + impl += "if (po.getOut() != null) map(\"out\") = po.getOut()" } else { - argDef += s"out: org.apache.mxnet.javaapi.NDArray" + impl += "if (out != null) map(\"out\") = out" } - impl += "if (" + header + "out != null) map(\"out\") = " + header + "out" - OptionArgDef += "var out : org.apache.mxnet.NDArray = null" - val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn" + val returnType = "Array[org.apache.mxnet.javaapi.NDArray]" // scalastyle:off // Combine and build the function string - impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + - ndarrayfunction.name + "\", args.toSeq, map.toMap)" + impl += "val finalArr = org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + + ndarrayfunction.name + "\", args.toSeq, map.toMap).arr" + impl += "finalArr.map(ele => new NDArray(ele))" if (useParamObject) { - val classDef = - s"""class ${ndarrayfunction.name}Param(${argDef.mkString(",")}) - | extends ${ndarrayfunction.name}ParamBase(${argDef.mkString(",")}) { - | ${OptionArgDef.mkString("\n")} - | ${classImpl.mkString("\n")} - | def invoke() : $returnType = { - | ${impl.mkString("\n")} - | } - | }""".stripMargin - classDefs += c.parse(classDef).asInstanceOf[ClassDef] val funcDef = - s"""def ${ndarrayfunction.name}(po: ${ndarrayfunction.name}ParamBase): $returnType = { - | po.invoke() + s"""def ${ndarrayfunction.name}(po: ${ndarrayfunction.name}Param): $returnType = { + | ${impl.mkString("\n")} | }""".stripMargin functionDefs += c.parse(funcDef).asInstanceOf[DefDef] } else { @@ -151,12 +129,11 @@ private[mxnet] object JavaNDArrayMacro { } } - structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*) + structGeneration(c)(functionDefs.toList, annottees : _*) } private def structGeneration(c: blackbox.Context) (funcDef : List[c.universe.DefDef], - classDef : List[c.universe.ClassDef], annottees: c.Expr[Any]*) : c.Expr[Any] = { import c.universe._ @@ -166,7 +143,7 @@ private[mxnet] object JavaNDArrayMacro { case ClassDef(mods, name, something, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } @@ -174,7 +151,7 @@ private[mxnet] object JavaNDArrayMacro { case ModuleDef(mods, name, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } From 7a8892c1d3ec9aced306775f4ad93fdf76d3ad82 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 8 Nov 2018 22:47:41 -0800 Subject: [PATCH 5/8] apply comments to Param --- .../scala/org/apache/mxnet/APIDocGenerator.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 8005e7e49206..34ea158e8ac0 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -116,9 +116,7 @@ private[mxnet] object APIDocGenerator{ val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")) .filterNot(ele => notGenerated.contains(ele.name)) .map(absClassFunction => { - val scalaDoc = generateAPIDocFromBackend(absClassFunction) - val defBody = generateJavaAPISignature(absClassFunction) - s"$scalaDoc\n$defBody" + generateJavaAPISignature(absClassFunction) }) val packageName = "NDArrayBase" val packageDef = "package org.apache.mxnet.javaapi" @@ -225,6 +223,8 @@ private[mxnet] object APIDocGenerator{ }) val experimentalTag = "@Experimental" val returnType = "Array[NDArray]" + val scalaDoc = generateAPIDocFromBackend(func) + val scalaDocNoParam = generateAPIDocFromBackend(func, false) if(useParamObject) { classDef += s"""private var out : org.apache.mxnet.NDArray = null @@ -234,14 +234,17 @@ private[mxnet] object APIDocGenerator{ | } | def getOut() = this.out | """.stripMargin - s"""$experimentalTag + s"""$scalaDocNoParam + | $experimentalTag | def ${func.name}(po: ${func.name}Param) : $returnType + | $scalaDoc | class ${func.name}Param(${argDef.mkString(",")}) { | ${classDef.mkString("\n ")} | }""".stripMargin } else { argDef += "out : NDArray" - s"""$experimentalTag + s"""$scalaDoc + |$experimentalTag | def ${func.name}(${argDef.mkString(", ")}) : $returnType | """.stripMargin } From fc202824a102fbdb7aa6a82a2c2767b268e0702d Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 9 Nov 2018 17:00:56 -0800 Subject: [PATCH 6/8] clean up param build --- .../src/main/scala/org/apache/mxnet/APIDocGenerator.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 34ea158e8ac0..3bd4a43cf2bc 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -225,6 +225,10 @@ private[mxnet] object APIDocGenerator{ val returnType = "Array[NDArray]" val scalaDoc = generateAPIDocFromBackend(func) val scalaDocNoParam = generateAPIDocFromBackend(func, false) + val params = func.listOfArgs.map({ absClassArg => + val currArgName = safetyNameCheck(absClassArg.argName) + s" * @param $currArgName\t\t${absClassArg.argDesc}" + }) if(useParamObject) { classDef += s"""private var out : org.apache.mxnet.NDArray = null @@ -237,7 +241,9 @@ private[mxnet] object APIDocGenerator{ s"""$scalaDocNoParam | $experimentalTag | def ${func.name}(po: ${func.name}Param) : $returnType - | $scalaDoc + | /** + | ${params.mkString("\n")} + | */ | class ${func.name}Param(${argDef.mkString(",")}) { | ${classDef.mkString("\n ")} | }""".stripMargin From 65105a7c6d1869de7ef0372e7713b8ae01704b01 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 9 Nov 2018 17:10:21 -0800 Subject: [PATCH 7/8] change on the comments --- .../main/scala/org/apache/mxnet/APIDocGenerator.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 3bd4a43cf2bc..e2efa676322a 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -204,18 +204,23 @@ private[mxnet] object APIDocGenerator{ val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2 var argDef = ListBuffer[String]() var classDef = ListBuffer[String]() + var requiredParam = ListBuffer[String]() func.listOfArgs.foreach(absClassArg => { val currArgName = safetyNameCheck(absClassArg.argName) // scalastyle:off if (absClassArg.isOptional && useParamObject) { classDef += s"""private var $currArgName: ${absClassArg.argType} = null + |/** + | * @param $currArgName\t\t${absClassArg.argDesc} + | */ |def set${currArgName.capitalize}($currArgName : ${absClassArg.argType}): ${func.name}Param = { | this.$currArgName = $currArgName | this | }""".stripMargin } else { + requiredParam += s" * @param $currArgName\t\t${absClassArg.argDesc}" argDef += s"$currArgName : ${absClassArg.argType}" } classDef += s"def get${currArgName.capitalize}() = this.$currArgName" @@ -225,10 +230,6 @@ private[mxnet] object APIDocGenerator{ val returnType = "Array[NDArray]" val scalaDoc = generateAPIDocFromBackend(func) val scalaDocNoParam = generateAPIDocFromBackend(func, false) - val params = func.listOfArgs.map({ absClassArg => - val currArgName = safetyNameCheck(absClassArg.argName) - s" * @param $currArgName\t\t${absClassArg.argDesc}" - }) if(useParamObject) { classDef += s"""private var out : org.apache.mxnet.NDArray = null @@ -242,7 +243,7 @@ private[mxnet] object APIDocGenerator{ | $experimentalTag | def ${func.name}(po: ${func.name}Param) : $returnType | /** - | ${params.mkString("\n")} + | ${requiredParam.mkString("\n")} | */ | class ${func.name}Param(${argDef.mkString(",")}) { | ${classDef.mkString("\n ")} From 19088d6c47baa9db975dedf8c3323c6009b175e4 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 12 Nov 2018 10:38:55 -0800 Subject: [PATCH 8/8] add one description line --- .../macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index e2efa676322a..f2326868e8e7 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -243,6 +243,7 @@ private[mxnet] object APIDocGenerator{ | $experimentalTag | def ${func.name}(po: ${func.name}Param) : $returnType | /** + | * This Param Object is specifically used for ${func.name} | ${requiredParam.mkString("\n")} | */ | class ${func.name}Param(${argDef.mkString(",")}) {