Skip to content

Commit

Permalink
FullyConnected, Activate and Group Symbol creation
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Jan 3, 2016
1 parent a41016e commit e1aa386
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@ class AttrScope(attr: Map[String, String] = Map.empty) {
* @param userDefinedAttr The attribute passed in by user during symbol creation.
* @return Updated attributes to add other scope related attributes.
*/
def get(userDefinedAttr: Map[String, String]): Map[String, String] = {
if (userDefinedAttr != null) {
attr ++ userDefinedAttr
} else {
attr
}
def get(userDefinedAttr: Option[Map[String, String]]): Map[String, String] = {
_attr ++ userDefinedAttr.getOrElse(Map.empty[String, String])
}

def withScope[T](body: => T): T = {
Expand Down
7 changes: 7 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,11 @@ class LibInfo {
key: String,
ret: RefString,
success: RefInt): Int
@native def mxSymbolListArguments(handle: SymbolHandle,
arguments: ArrayBuffer[String]): Int
@native def mxSymbolCopy(handle: SymbolHandle, clonedHandle: SymbolHandleRef): Int
@native def mxSymbolListOutputs(handle: SymbolHandle,
outputs: ArrayBuffer[String]): Int
@native def mxSymbolCreateGroup(handles: Array[SymbolHandle], out: SymbolHandleRef): Int
@native def mxSymbolPrint(handle: SymbolHandle, str: RefString): Int
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ class NameManager {
* @param hint : A hint string, which can be used to generate name.
* @return A canonical name for the user.
*/
def get(name: String, hint: String): String = {
if (name != null) {
name
} else {
def get(name: Option[String], hint: String): String = {
name.getOrElse {
if (!counter.contains(hint)) {
counter(hint) = 0
}
Expand Down
170 changes: 136 additions & 34 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,40 @@ package ml.dmlc.mxnet
import ml.dmlc.mxnet.Base._
import org.slf4j.LoggerFactory

import scala.collection.mutable.ListBuffer
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/**
* Symbolic configuration API of mxnet.
* @author Yizhi Liu
*/
class Symbol(private[mxnet] val handle: SymbolHandle) {
def +(other: Symbol): Symbol = Symbol.creator("_Plus", other)
def +(other: Int): Symbol = ???
def +(other: Float): Symbol = ???
def +(other: Double): Symbol = ???
def +(other: Symbol): Symbol = Symbol.create("_Plus", other)

override def clone(): Symbol = {
val clonedHandle = new SymbolHandleRef
checkCall(_LIB.mxSymbolCopy(handle, clonedHandle))
new Symbol(clonedHandle.value)
}

/**
* List all the arguments in the symbol.
* @return Array of all the arguments.
*/
def listArguments(): Array[String] = ???
def listArguments(): Array[String] = {
val arr = ArrayBuffer.empty[String]
checkCall(_LIB.mxSymbolListArguments(handle, arr))
arr.toArray
}

/**
* List all outputs in the symbol.
* @return : List of all the outputs.
*/
def listOutputs(): Array[String] = {
val arr = ArrayBuffer.empty[String]
checkCall(_LIB.mxSymbolListOutputs(handle, arr))
arr.toArray
}

/**
* List all auxiliary states in the symbol.
Expand Down Expand Up @@ -49,6 +66,28 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
}
}

/**
* Invoke symbol as function on inputs.
* @param name resulting symbol name
* @param symbols provide named symbols
* @return the resulting symbol
*/
def apply(name: String, symbols: Map[String, Symbol]): Symbol = {
val s = clone()
s.compose(name, symbols)
s
}

/**
* Get a debug string.
* @return Debug string of the symbol.
*/
def debugStr: String = {
val str = new RefString
checkCall(_LIB.mxSymbolPrint(handle, str))
str.value
}

// Set the attribute of the symbol.
private def setAttr(attr: Map[String, String]): Unit = {
attr.foreach { case (key, value) =>
Expand All @@ -59,6 +98,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
/**
* Compose symbol on inputs.
* This call mutates the current symbol.
* @param name resulting symbol name
* @param symbols provide positional arguments
* @return the resulting symbol
*/
Expand All @@ -70,7 +110,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
private def compose(name: String, symbols: Map[String, Symbol]): Unit = {
val keys = symbols.keys.toArray
val args = symbols.values.map(_.handle).toArray
checkCall(_LIB.mxSymbolCompose(handle, name, null, args))
checkCall(_LIB.mxSymbolCompose(handle, name, keys, args))
}
}

Expand All @@ -88,10 +128,38 @@ object Symbol {
val handle = new SymbolHandleRef
checkCall(_LIB.mxSymbolCreateVariable(name, handle))
val sym = new Symbol(handle.value)
sym.setAttr(AttrScope.current.get(attr))
sym.setAttr(AttrScope.current.get(Option(attr)))
sym
}

def FullyConnected: Map[String, Any] => Symbol = {
FullyConnected(null)
}

def FullyConnected(attr: Map[String, String]): Map[String, Any] => Symbol = {
createNoCheck("FullyConnected", attr)
}

def Activation: Map[String, Any] => Symbol = {
Activation(null)
}

def Activation(attr: Map[String, String]): Map[String, Any] => Symbol = {
createNoCheck("Activation", attr)
}

/**
* Create a symbol that groups symbols together.
* @param symbols List of symbols to be grouped.
* @return The created group symbol.
*/
def Group(symbols: Symbol*): Symbol = {
val ihandles = symbols.map(_.handle).toArray
val handle = new SymbolHandleRef
checkCall(_LIB.mxSymbolCreateGroup(ihandles, handle))
new Symbol(handle.value)
}

// List and add all the atomic symbol functions to current module.
private def initSymbolModule(): Map[String, SymbolFunction] = {
val symbolList = ListBuffer.empty[SymbolHandle]
Expand Down Expand Up @@ -120,79 +188,113 @@ object Symbol {
/**
* Activation Operator of Neural Net.
* The parameters listed below can be passed in as keyword arguments.
* @param name Name of the resulting symbol.
* // TODO
* @param symbols Symbol parameters passed to create the resulting symbol
* @param paramKwargs Key-value parameters passed to create the resulting symbol
* @param attr Attributes set to the resulting symbol
* @return the resulting symbol
*/
private def creator(operator: String,
name: String,
attr: Map[String, String],
paramKwargs: Map[String, String],
symbols: Symbol*): Symbol = {
def create(operator: String,
symbols: Array[Symbol],
paramKwargs: Map[String, String],
attr: Map[String, String]): Symbol = {
val function = functions(operator)
require(function != null, s"invalid operator name $operator")

val params = if (paramKwargs == null) Map.empty[String, String] else paramKwargs
val addkeyVarNumArgs = (function.keyVarNumArgs != null
&& !function.keyVarNumArgs.isEmpty
&& !paramKwargs.contains(function.keyVarNumArgs))
&& !params.contains(function.keyVarNumArgs))

val paramKeys: Array[String] = (
if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs)
else Array.empty[String]
) ++ paramKwargs.keys
) ++ (params - "name").keys
val paramVals: Array[String] = (
if (addkeyVarNumArgs) Array[String](symbols.length.toString)
else Array.empty[String]
) ++ paramKwargs.values
) ++ (params - "name").values

// create atomic symbol
val symHandle = new SymbolHandleRef
checkCall(_LIB.mxSymbolCreateAtomicSymbol(
function.handle, paramKeys, paramVals, symHandle))

val s = new Symbol(symHandle.value)
val attrAll = AttrScope.current.get(attr)
val attrAll = AttrScope.current.get(Option(attr))
s.setAttr(attrAll)
val hint = operator.toLowerCase
val managedName = NameManager.current.get(name, hint)
s.compose(managedName, symbols.toArray)
val managedName = NameManager.current.get(params.get("name"), hint)
s.compose(managedName, symbols)
s
}

private def creator(operator: String, symbols: Symbol*): Symbol = {
creator(operator, null, null, Map.empty[String, String], symbols:_*)
def create(operator: String, symbols: Symbol*): Symbol = {
create(operator, symbols.toArray, null, null)
}

private def creator(operator: String,
name: String,
attr: Map[String, String],
paramKwargs: Map[String, String],
symbols: Map[String, Symbol]): Symbol = {
/**
* Activation Operator of Neural Net.
* The parameters listed below can be passed in as keyword arguments.
* @param symbols Named symbol parameters passed to create the resulting symbol
* @param paramKwargs Key-value parameters passed to create the resulting symbol
* @param attr Attributes set to the resulting symbol
* @return the resulting symbol
*/
private def create(operator: String,
symbols: Map[String, Symbol],
paramKwargs: Map[String, String],
attr: Map[String, String]): Symbol = {
val function = functions(operator)
require(function != null, s"invalid operator name $operator")
require(function.keyVarNumArgs == null || function.keyVarNumArgs.isEmpty,
"This function support variable length of Symbol arguments.\n" +
"Please pass all the input Symbols via positional arguments instead of keyword arguments.")

val paramKeys = paramKwargs.keys.toArray
val paramVals = paramKwargs.values.toArray
val paramKeys =
if (paramKwargs == null) Array.empty[String]
else (paramKwargs - "name").keys.toArray
val paramVals =
if (paramKwargs == null) Array.empty[String]
else (paramKwargs - "name").values.toArray
val symHandle = new SymbolHandleRef
checkCall(_LIB.mxSymbolCreateAtomicSymbol(
function.handle, paramKeys, paramVals, symHandle))

val s = new Symbol(symHandle.value)
val attrAll = AttrScope.current.get(attr)
val attrAll = AttrScope.current.get(Option(attr))
s.setAttr(attrAll)
val hint = operator.toLowerCase
val managedName = NameManager.current.get(name, hint)
val managedName = NameManager.current.get(paramKwargs.get("name"), hint)
s.compose(managedName, symbols)
s
}

private def creator(operator: String, symbols: Map[String, Symbol]): Symbol = {
creator(operator, null, null, Map.empty[String, String], symbols)
def create(operator: String, symbols: Map[String, Symbol]): Symbol = {
create(operator, symbols, null, null)
}

def create(operator: String,
symbols: Map[String, Symbol],
paramKwargs: Map[String, String]): Symbol = {
create(operator, symbols, paramKwargs, null)
}

// a more friendly interface for creating symbols
// all values except symbols in kwargs will be cast to String using its toString() method
def createNoCheck(operator: String, attr: Map[String, String] = null)(
kwargs: Map[String, Any]): Symbol = {
val symbolArgs = kwargs.filter { case (key, value) =>
value.isInstanceOf[Symbol]
}.map { case (key, value) =>
(key, value.asInstanceOf[Symbol])
}
val strArgs = kwargs.filter { case (key, value) =>
!value.isInstanceOf[Symbol]
}.map { case (key, value) =>
(key, value.toString)
}
create(operator, symbolArgs, strArgs, attr)
}
}

private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String)
26 changes: 21 additions & 5 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,26 @@ package ml.dmlc.mxnet
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class SymbolSuite extends FunSuite with BeforeAndAfterAll {
test("plus") {
val sym1 = Symbol.Variable("data1")
val sym2 = Symbol.Variable("data2")
val symPlus = sym1 + sym2
// TODO: check result
test("symbol compose") {
val data = Symbol.Variable("data")

var net1 = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10))
net1 = Symbol.FullyConnected(Map("data" -> net1, "name" -> "fc2", "num_hidden" -> 100))
assert(net1.listArguments() ===
Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"))

var net2 = Symbol.FullyConnected(Map("name" -> "fc3", "num_hidden" -> 10))
net2 = Symbol.Activation(Map("data" -> net2, "act_type" -> "relu"))
net2 = Symbol.FullyConnected(Map("data" -> net2, "name" -> "fc4", "num_hidden" -> 20))
// scalastyle:off println
println(s"net2 debug info:\n${net2.debugStr}")
// scalastyle:on println

val composed = net2(name = "composed", Map("fc3_data" -> net1))
// scalastyle:off println
println(s"composed debug info:\n${composed.debugStr}")
// scalastyle:on println
val multiOut = Symbol.Group(composed, net1)
assert(multiOut.listOutputs().length === 2)
}
}
Loading

0 comments on commit e1aa386

Please sign in to comment.