Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Symbol Random api
Browse files Browse the repository at this point in the history
  • Loading branch information
mdespriee committed Sep 8, 2018
1 parent e290623 commit cb574c5
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,8 @@ object Symbol extends SymbolBase {

val api = SymbolAPI

val random = SymbolRandomAPI

def pow(sym1: Symbol, sym2: Symbol): Symbol = {
Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
}
Expand Down
10 changes: 10 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase {
Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
}
}

@AddSymbolRandomAPIs(false)
/**
* typesafe Symbol API: Symbol.random._
* Main code will be generated during compile time through Macros
*/
object SymbolRandomAPI extends SymbolRandomAPIBase {

}

Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@ package org.apache.mxnet
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class SymbolSuite extends FunSuite with BeforeAndAfterAll {

test("random module - normal") {
val loc = Symbol.Variable("loc")
val scale = Symbol.Variable("scale")
val rnd = Symbol.random.sample_normal(mu = Some(loc), sigma = Some(scale),
shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random.random_normal(loc = Some(1f), scale = Some(2f),
shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.normal Symbol args debug info: ${rnd.debugStr}")
println(s"Symbol.random.normal Symbol args debug info: ${rnd2.debugStr}")
// scalastyle:on println
}

test("symbol compose") {
val data = Symbol.Variable("data")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ private[mxnet] object APIDocGenerator{
val hashCollector = ListBuffer[String]()
hashCollector += absClassGen(FILE_PATH, true)
hashCollector += absClassGen(FILE_PATH, false)
hashCollector += absRndClassGen(FILE_PATH, true)
// hashCollector += absClassGen(FILE_PATH, false) // TODO random NDArray
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
val finalHash = hashCollector.mkString("\n")
Expand All @@ -52,13 +54,39 @@ private[mxnet] object APIDocGenerator{
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}

def absRndClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filter(f => f.name.startsWith("sample") || f.name.startsWith("random"))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateAPISignature(absClassFunction, isSymbol)
s"$scalaDoc\n$defBody"
})
val packageName = if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase"
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
val scalaStyle = "// scalastyle:off"
val packageDef = "package org.apache.mxnet"
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
MD5Generator(finalStr)
}

def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
// Defines Operators that should not generated
val notGenerated = Set("Custom")
// TODO: Add Filter to the same location in case of refactor
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
val absFuncs = absClassFunctions
.filterNot(_.name.startsWith("_"))
.filterNot(_.name.startsWith("sample"))
.filterNot(_.name.startsWith("random"))
.filterNot(ele => notGenerated.contains(ele.name))
.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
Expand Down
111 changes: 68 additions & 43 deletions scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation
private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typeSafeAPIDefs
}

private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends StaticAnnotation {
private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typedRandomAPIDefs
}


private[mxnet] object SymbolImplMacros {
case class SymbolArg(argName: String, argType: String, isOptional : Boolean)
case class SymbolFunction(name: String, listOfArgs: List[SymbolArg])
Expand All @@ -43,6 +48,9 @@ private[mxnet] object SymbolImplMacros {
def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
typedAPIImpl(c)(annottees: _*)
}
def typedRandomAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
typedRandomAPIImpl(c)(annottees: _*)
}
// scalastyle:on havetype

private val symbolFunctions: List[SymbolFunction] = initSymbolModule()
Expand All @@ -63,7 +71,6 @@ private[mxnet] object SymbolImplMacros {
else symbolFunctions.filter(!_.name.startsWith("_"))
}


val functionDefs = newSymbolFunctions map { symbolfunction =>
val funcName = symbolfunction.name
val tName = TermName(funcName)
Expand All @@ -79,10 +86,23 @@ private[mxnet] object SymbolImplMacros {
structGeneration(c)(functionDefs, annottees : _*)
}

/**
* Implementation for Dynamic typed API Symbol.random.<functioname>
*/
private def typedRandomAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._

val rndFunctions =
symbolFunctions.filter(f => f.name.startsWith("sample") || f.name.startsWith("random"))

val functionDefs = rndFunctions.map(f => buildTypedFunction(c)(f))
structGeneration(c)(functionDefs, annottees: _*)
}

/**
* Implementation for Dynamic typed API Symbol.api.<functioname>
*/
private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._

val isContrib: Boolean = c.prefix.tree match {
Expand All @@ -97,34 +117,41 @@ private[mxnet] object SymbolImplMacros {
val newSymbolFunctions = {
if (isContrib) symbolFunctions.filter(
func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
else symbolFunctions.filter(!_.name.startsWith("_"))
else symbolFunctions.filter(f => !f.name.startsWith("_") &&
!f.name.startsWith("sample") && !f.name.startsWith("random"))
}.filterNot(ele => notGenerated.contains(ele.name))

val functionDefs = newSymbolFunctions map { symbolfunction =>
val functionDefs = newSymbolFunctions.map(f => buildTypedFunction(c)(f))
structGeneration(c)(functionDefs, annottees : _*)
}

// Construct argument field
var argDef = ListBuffer[String]()
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
impl += "var args = Seq[org.apache.mxnet.Symbol]()"
symbolfunction.listOfArgs.foreach({ symbolarg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
val currArgName = symbolarg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case default => symbolarg.argName
}
if (symbolarg.isOptional) {
argDef += s"${currArgName} : Option[${symbolarg.argType}] = None"
}
else {
argDef += s"${currArgName} : ${symbolarg.argType}"
}
// Symbol arg implementation
val returnType = "org.apache.mxnet.Symbol"
val base =
private def buildTypedFunction(c: blackbox.Context)
(symbolfunction: SymbolFunction): c.universe.DefDef = {
import c.universe._

// Construct argument field
var argDef = ListBuffer[String]()
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
impl += "var args = Seq[org.apache.mxnet.Symbol]()"
symbolfunction.listOfArgs.foreach({ symbolarg =>
// var is a special word used to define variable in Scala,
// need to changed to something else in order to make it work
val currArgName = symbolarg.argName match {
case "var" => "vari"
case "type" => "typeOf"
case default => symbolarg.argName
}
if (symbolarg.isOptional) {
argDef += s"${currArgName} : Option[${symbolarg.argType}] = None"
}
else {
argDef += s"${currArgName} : ${symbolarg.argType}"
}
// Symbol arg implementation
val returnType = "org.apache.mxnet.Symbol"
val base =
if (symbolarg.argType.equals(s"Array[$returnType]")) {
if (symbolarg.isOptional) s"if (!$currArgName.isEmpty) args = $currArgName.get.toSeq"
else s"args = $currArgName.toSeq"
Expand All @@ -137,22 +164,20 @@ private[mxnet] object SymbolImplMacros {
else "map(\"" + symbolarg.argName + "\"" + s") = $currArgName"
}

impl += base
})
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
// scalastyle:off
// TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated
impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, args, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.Symbol"
var finalStr = s"def ${symbolfunction.name}"
finalStr += s" (${argDef.mkString(",")}) : $returnType"
finalStr += s" = {${impl.mkString("\n")}}"
c.parse(finalStr).asInstanceOf[DefDef]
}
structGeneration(c)(functionDefs, annottees : _*)
impl += base
})
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
// scalastyle:off
// TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated
impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, args, map.toMap)"
// scalastyle:on
// Combine and build the function string
val returnType = "org.apache.mxnet.Symbol"
var finalStr = s"def ${symbolfunction.name}"
finalStr += s" (${argDef.mkString(",")}) : $returnType"
finalStr += s" = {${impl.mkString("\n")}}"
c.parse(finalStr).asInstanceOf[DefDef]
}

/**
Expand Down

0 comments on commit cb574c5

Please sign in to comment.