diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index e3e1a320358e..20d14eaf162f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -835,6 +835,8 @@ object Symbol extends SymbolBase { val api = SymbolAPI + val random = SymbolRandomAPI + def pow(sym1: Symbol, sym2: Symbol): Symbol = { Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2)) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala index 1bfb0559cf96..57102ca41f9c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala @@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase { Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap) } } + +@AddSymbolRandomAPIs(false) +/** + * typesafe Symbol API: Symbol.random._ + * Main code will be generated during compile time through Macros + */ +object SymbolRandomAPI extends SymbolRandomAPIBase { + +} + diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala index ebb61d7d4bfb..d3b780106cce 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala @@ -20,6 +20,20 @@ package org.apache.mxnet import org.scalatest.{BeforeAndAfterAll, FunSuite} class SymbolSuite extends FunSuite with BeforeAndAfterAll { + + test("random module - normal") { + val loc = Symbol.Variable("loc") + val scale = Symbol.Variable("scale") + val rnd = Symbol.random.sample_normal(mu = Some(loc), sigma = Some(scale), + shape = Some(Shape(2, 2))) + val rnd2 = Symbol.random.random_normal(loc = Some(1f), scale = Some(2f), + shape = Some(Shape(2, 2))) + // scalastyle:off println + println(s"Symbol.random.normal Symbol args debug info: ${rnd.debugStr}") + println(s"Symbol.random.normal Symbol args debug info: ${rnd2.debugStr}") + // scalastyle:on println + } + test("symbol compose") { val data = Symbol.Variable("data") diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala index b4efa659443c..02d0435787c7 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala @@ -40,6 +40,8 @@ private[mxnet] object APIDocGenerator{ val hashCollector = ListBuffer[String]() hashCollector += absClassGen(FILE_PATH, true) hashCollector += absClassGen(FILE_PATH, false) + hashCollector += absRndClassGen(FILE_PATH, true) +// hashCollector += absClassGen(FILE_PATH, false) // TODO random NDArray hashCollector += nonTypeSafeClassGen(FILE_PATH, true) hashCollector += nonTypeSafeClassGen(FILE_PATH, false) val finalHash = hashCollector.mkString("\n") @@ -52,13 +54,39 @@ private[mxnet] object APIDocGenerator{ org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest) } + def absRndClassGen(FILE_PATH : String, isSymbol : Boolean) : String = { + // scalastyle:off + val absClassFunctions = getSymbolNDArrayMethods(isSymbol) + // TODO: Add Filter to the same location in case of refactor + val absFuncs = absClassFunctions.filter(f => f.name.startsWith("sample") || f.name.startsWith("random")) + .map(absClassFunction => { + val scalaDoc = generateAPIDocFromBackend(absClassFunction) + val defBody = generateAPISignature(absClassFunction, isSymbol) + s"$scalaDoc\n$defBody" + }) + val packageName = if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase" + val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n" + val scalaStyle = "// scalastyle:off" + val packageDef = "package org.apache.mxnet" + val imports = "import org.apache.mxnet.annotation.Experimental" + val absClassDef = s"abstract class $packageName" + val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}" + val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala")) + pw.write(finalStr) + pw.close() + MD5Generator(finalStr) + } + def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = { // scalastyle:off val absClassFunctions = getSymbolNDArrayMethods(isSymbol) // Defines Operators that should not generated val notGenerated = Set("Custom") // TODO: Add Filter to the same location in case of refactor - val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")) + val absFuncs = absClassFunctions + .filterNot(_.name.startsWith("_")) + .filterNot(_.name.startsWith("sample")) + .filterNot(_.name.startsWith("random")) .filterNot(ele => notGenerated.contains(ele.name)) .map(absClassFunction => { val scalaDoc = generateAPIDocFromBackend(absClassFunction) diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index 42aa11781d8f..4544046d01bc 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -32,6 +32,11 @@ private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typeSafeAPIDefs } +private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typedRandomAPIDefs +} + + private[mxnet] object SymbolImplMacros { case class SymbolArg(argName: String, argType: String, isOptional : Boolean) case class SymbolFunction(name: String, listOfArgs: List[SymbolArg]) @@ -43,6 +48,9 @@ private[mxnet] object SymbolImplMacros { def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { typedAPIImpl(c)(annottees: _*) } + def typedRandomAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + typedRandomAPIImpl(c)(annottees: _*) + } // scalastyle:on havetype private val symbolFunctions: List[SymbolFunction] = initSymbolModule() @@ -63,7 +71,6 @@ private[mxnet] object SymbolImplMacros { else symbolFunctions.filter(!_.name.startsWith("_")) } - val functionDefs = newSymbolFunctions map { symbolfunction => val funcName = symbolfunction.name val tName = TermName(funcName) @@ -79,10 +86,23 @@ private[mxnet] object SymbolImplMacros { structGeneration(c)(functionDefs, annottees : _*) } + /** + * Implementation for Dynamic typed API Symbol.random. + */ + private def typedRandomAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { + import c.universe._ + + val rndFunctions = + symbolFunctions.filter(f => f.name.startsWith("sample") || f.name.startsWith("random")) + + val functionDefs = rndFunctions.map(f => buildTypedFunction(c)(f)) + structGeneration(c)(functionDefs, annottees: _*) + } + /** * Implementation for Dynamic typed API Symbol.api. */ - private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { + private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { @@ -97,34 +117,41 @@ private[mxnet] object SymbolImplMacros { val newSymbolFunctions = { if (isContrib) symbolFunctions.filter( func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) - else symbolFunctions.filter(!_.name.startsWith("_")) + else symbolFunctions.filter(f => !f.name.startsWith("_") && + !f.name.startsWith("sample") && !f.name.startsWith("random")) }.filterNot(ele => notGenerated.contains(ele.name)) - val functionDefs = newSymbolFunctions map { symbolfunction => + val functionDefs = newSymbolFunctions.map(f => buildTypedFunction(c)(f)) + structGeneration(c)(functionDefs, annottees : _*) + } - // Construct argument field - var argDef = ListBuffer[String]() - // Construct Implementation field - var impl = ListBuffer[String]() - impl += "val map = scala.collection.mutable.Map[String, Any]()" - impl += "var args = Seq[org.apache.mxnet.Symbol]()" - symbolfunction.listOfArgs.foreach({ symbolarg => - // var is a special word used to define variable in Scala, - // need to changed to something else in order to make it work - val currArgName = symbolarg.argName match { - case "var" => "vari" - case "type" => "typeOf" - case default => symbolarg.argName - } - if (symbolarg.isOptional) { - argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" - } - else { - argDef += s"${currArgName} : ${symbolarg.argType}" - } - // Symbol arg implementation - val returnType = "org.apache.mxnet.Symbol" - val base = + private def buildTypedFunction(c: blackbox.Context) + (symbolfunction: SymbolFunction): c.universe.DefDef = { + import c.universe._ + + // Construct argument field + var argDef = ListBuffer[String]() + // Construct Implementation field + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + impl += "var args = Seq[org.apache.mxnet.Symbol]()" + symbolfunction.listOfArgs.foreach({ symbolarg => + // var is a special word used to define variable in Scala, + // need to changed to something else in order to make it work + val currArgName = symbolarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => symbolarg.argName + } + if (symbolarg.isOptional) { + argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${symbolarg.argType}" + } + // Symbol arg implementation + val returnType = "org.apache.mxnet.Symbol" + val base = if (symbolarg.argType.equals(s"Array[$returnType]")) { if (symbolarg.isOptional) s"if (!$currArgName.isEmpty) args = $currArgName.get.toSeq" else s"args = $currArgName.toSeq" @@ -137,22 +164,20 @@ private[mxnet] object SymbolImplMacros { else "map(\"" + symbolarg.argName + "\"" + s") = $currArgName" } - impl += base - }) - argDef += "name : String = null" - argDef += "attr : Map[String, String] = null" - // scalastyle:off - // TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated - impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, args, map.toMap)" - // scalastyle:on - // Combine and build the function string - val returnType = "org.apache.mxnet.Symbol" - var finalStr = s"def ${symbolfunction.name}" - finalStr += s" (${argDef.mkString(",")}) : $returnType" - finalStr += s" = {${impl.mkString("\n")}}" - c.parse(finalStr).asInstanceOf[DefDef] - } - structGeneration(c)(functionDefs, annottees : _*) + impl += base + }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + // scalastyle:off + // TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated + impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, args, map.toMap)" + // scalastyle:on + // Combine and build the function string + val returnType = "org.apache.mxnet.Symbol" + var finalStr = s"def ${symbolfunction.name}" + finalStr += s" (${argDef.mkString(",")}) : $returnType" + finalStr += s" = {${impl.mkString("\n")}}" + c.parse(finalStr).asInstanceOf[DefDef] } /**