diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala index d4e67f73408e..cdcc292ada63 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala @@ -385,17 +385,3 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) { override def equals(obj: Any): Boolean = nd.equals(obj) override def hashCode(): Int = nd.hashCode } - -object NDArrayFuncReturn { - implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn) - : org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn - implicit def toJavaNDFuncReturn(ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) - : NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn) -} - -private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) { - def head : NDArray = ndFuncReturn.head - def get : NDArray = ndFuncReturn.get - def apply(i : Int) : NDArray = ndFuncReturn.apply(i) - // TODO: Add JavaNDArray operational stuff -} diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java index a9bad83f62d6..2659b7848bc6 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java @@ -19,9 +19,9 @@ import org.junit.Test; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.mxnet.javaapi.NDArrayBase.*; import static org.junit.Assert.assertTrue; @@ -71,7 +71,7 @@ public void testGenerated(){ NDArray$ NDArray = NDArray$.MODULE$; float[] arr = new float[]{1.0f, 2.0f, 3.0f}; NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0)); - float result = NDArray.norm(nd).invoke().get().toArray()[0]; + float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0]; float cal = 0.0f; for (float ele : arr) { cal += ele * ele; @@ -79,7 +79,7 @@ public void testGenerated(){ cal = (float) Math.sqrt(cal); assertTrue(Math.abs(result - cal) < 1e-5); NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0)); - NDArray.dot(nd, nd).setout(dotResult).invoke().get(); + NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult)); assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f})); } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index 44d47a2099d5..f2326868e8e7 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -116,9 +116,7 @@ private[mxnet] object APIDocGenerator{ val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")) .filterNot(ele => notGenerated.contains(ele.name)) .map(absClassFunction => { - val scalaDoc = generateAPIDocFromBackend(absClassFunction) - val defBody = generateJavaAPISignature(absClassFunction) - s"$scalaDoc\n$defBody" + generateJavaAPISignature(absClassFunction) }) val packageName = "NDArrayBase" val packageDef = "package org.apache.mxnet.javaapi" @@ -203,27 +201,61 @@ private[mxnet] object APIDocGenerator{ } def generateJavaAPISignature(func : absClassFunction) : String = { + val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2 var argDef = ListBuffer[String]() var classDef = ListBuffer[String]() + var requiredParam = ListBuffer[String]() func.listOfArgs.foreach(absClassArg => { val currArgName = safetyNameCheck(absClassArg.argName) // scalastyle:off - if (absClassArg.isOptional) { - classDef += s"def set${absClassArg.argName}(${absClassArg.argName} : ${absClassArg.argType}) : ${func.name}BuilderBase" + if (absClassArg.isOptional && useParamObject) { + classDef += + s"""private var $currArgName: ${absClassArg.argType} = null + |/** + | * @param $currArgName\t\t${absClassArg.argDesc} + | */ + |def set${currArgName.capitalize}($currArgName : ${absClassArg.argType}): ${func.name}Param = { + | this.$currArgName = $currArgName + | this + | }""".stripMargin } else { + requiredParam += s" * @param $currArgName\t\t${absClassArg.argDesc}" argDef += s"$currArgName : ${absClassArg.argType}" } + classDef += s"def get${currArgName.capitalize}() = this.$currArgName" // scalastyle:on }) - classDef += s"def setout(out : NDArray) : ${func.name}BuilderBase" - classDef += s"def invoke() : org.apache.mxnet.javaapi.NDArrayFuncReturn" val experimentalTag = "@Experimental" - // scalastyle:off - var finalStr = s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : ${func.name}BuilderBase\n" - // scalastyle:on - finalStr += s"abstract class ${func.name}BuilderBase {\n ${classDef.mkString("\n ")}\n}" - finalStr + val returnType = "Array[NDArray]" + val scalaDoc = generateAPIDocFromBackend(func) + val scalaDocNoParam = generateAPIDocFromBackend(func, false) + if(useParamObject) { + classDef += + s"""private var out : org.apache.mxnet.NDArray = null + |def setOut(out : NDArray) : ${func.name}Param = { + | this.out = out + | this + | } + | def getOut() = this.out + | """.stripMargin + s"""$scalaDocNoParam + | $experimentalTag + | def ${func.name}(po: ${func.name}Param) : $returnType + | /** + | * This Param Object is specifically used for ${func.name} + | ${requiredParam.mkString("\n")} + | */ + | class ${func.name}Param(${argDef.mkString(",")}) { + | ${classDef.mkString("\n ")} + | }""".stripMargin + } else { + argDef += "out : NDArray" + s"""$scalaDoc + |$experimentalTag + | def ${func.name}(${argDef.mkString(", ")}) : $returnType + | """.stripMargin + } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala index d5be97b501c5..2d1827038afc 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -68,18 +68,14 @@ private[mxnet] object JavaNDArrayMacro { newNDArrayFunctions.foreach { ndarrayfunction => + val useParamObject = ndarrayfunction.listOfArgs.count(arg => arg.isOptional) >= 2 // Construct argument field with all required args var argDef = ListBuffer[String]() - // Construct Optional Arg - var OptionArgDef = ListBuffer[String]() // Construct function Implementation field (e.g norm) var impl = ListBuffer[String]() impl += "val map = scala.collection.mutable.Map[String, Any]()" - // scalastyle:off - impl += "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]" - // scalastyle:on - // Construct Class Implementation (e.g normBuilder) - var classImpl = ListBuffer[String]() + impl += + "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]" ndarrayfunction.listOfArgs.foreach({ ndarrayArg => // var is a special word used to define variable in Scala, // need to changed to something else in order to make it work @@ -88,55 +84,56 @@ private[mxnet] object JavaNDArrayMacro { case "type" => "typeOf" case _ => ndarrayArg.argName } - if (ndarrayArg.isOptional) { - OptionArgDef += s"private var $currArgName : ${ndarrayArg.argType} = null" - val tempDef = s"def set$currArgName($currArgName : ${ndarrayArg.argType})" - val tempImpl = s"this.$currArgName = $currArgName\nthis" - classImpl += s"$tempDef = {$tempImpl}" - } else { - argDef += s"$currArgName : ${ndarrayArg.argType}" - } + if (useParamObject) currArgName = s"po.get${currArgName.capitalize}()" + argDef += s"$currArgName : ${ndarrayArg.argType}" // NDArray arg implementation val returnType = "org.apache.mxnet.javaapi.NDArray" val base = if (ndarrayArg.argType.equals(returnType)) { - s"args += this.$currArgName" + s"args += $currArgName" } else if (ndarrayArg.argType.equals(s"Array[$returnType]")){ - s"this.$currArgName.foreach(args+=_)" + s"$currArgName.foreach(args+=_)" } else { - "map(\"" + ndarrayArg.argName + "\") = this." + currArgName + "map(\"" + ndarrayArg.argName + "\") = " + currArgName } impl.append( - if (ndarrayArg.isOptional) s"if (this.$currArgName != null) $base" + if (ndarrayArg.isOptional) s"if ($currArgName != null) $base" else base ) }) // add default out parameter - classImpl += - "def setout(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}" - impl += "if (this.out != null) map(\"out\") = this.out" - OptionArgDef += "private var out : org.apache.mxnet.NDArray = null" - val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn" + argDef += s"out: org.apache.mxnet.javaapi.NDArray" + if (useParamObject) { + impl += "if (po.getOut() != null) map(\"out\") = po.getOut()" + } else { + impl += "if (out != null) map(\"out\") = out" + } + val returnType = "Array[org.apache.mxnet.javaapi.NDArray]" // scalastyle:off // Combine and build the function string - impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + ndarrayfunction.name + "\", args.toSeq, map.toMap)" - val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")}) extends ${ndarrayfunction.name}BuilderBase" - val classBody = s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}" - val classFinal = s"$classDef {$classBody}" - val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})" - val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})" - val functionFinal = s"$functionDef : ${ndarrayfunction.name}BuilderBase = $functionBody" - // scalastyle:on - functionDefs += c.parse(functionFinal).asInstanceOf[DefDef] - classDefs += c.parse(classFinal).asInstanceOf[ClassDef] + impl += "val finalArr = org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + + ndarrayfunction.name + "\", args.toSeq, map.toMap).arr" + impl += "finalArr.map(ele => new NDArray(ele))" + if (useParamObject) { + val funcDef = + s"""def ${ndarrayfunction.name}(po: ${ndarrayfunction.name}Param): $returnType = { + | ${impl.mkString("\n")} + | }""".stripMargin + functionDefs += c.parse(funcDef).asInstanceOf[DefDef] + } else { + val funcDef = + s"""def ${ndarrayfunction.name}(${argDef.mkString(",")}): $returnType = { + | ${impl.mkString("\n")} + | }""".stripMargin + functionDefs += c.parse(funcDef).asInstanceOf[DefDef] + } } - structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*) + structGeneration(c)(functionDefs.toList, annottees : _*) } private def structGeneration(c: blackbox.Context) (funcDef : List[c.universe.DefDef], - classDef : List[c.universe.ClassDef], annottees: c.Expr[Any]*) : c.Expr[Any] = { import c.universe._ @@ -146,7 +143,7 @@ private[mxnet] object JavaNDArrayMacro { case ClassDef(mods, name, something, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } @@ -154,7 +151,7 @@ private[mxnet] object JavaNDArrayMacro { case ModuleDef(mods, name, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") }