Skip to content

Commit

Permalink
Symbol plus tested, but not assert final result
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Jan 2, 2016
1 parent 12b11a7 commit a41016e
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 86 deletions.
20 changes: 10 additions & 10 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
package ml.dmlc.mxnet

object AttrScope {
private var _current = new AttrScope()
def current: AttrScope = _current
private def setCurrentAttr(attr: AttrScope): Unit = {
_current = attr
}

def apply(attr: Map[String, String] = Map.empty): AttrScope = new AttrScope(attr)
}

/**
* Attribute manager for scoping.
* User can also inherit this object to change naming behavior.
Expand Down Expand Up @@ -41,3 +31,13 @@ class AttrScope(attr: Map[String, String] = Map.empty) {
}
}
}

object AttrScope {
private var _current = new AttrScope()
def current: AttrScope = _current
private def setCurrentAttr(attr: AttrScope): Unit = {
_current = attr
}

def apply(attr: Map[String, String] = Map.empty): AttrScope = new AttrScope(attr)
}
19 changes: 8 additions & 11 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,13 @@ class NameManager {
val counter: mutable.Map[String, Int] = mutable.HashMap.empty[String, Int]
/**
* Get the canonical name for a symbol.
* This is default implementation.
* When user specified a name,
* the user specified name will be used.
* This is default implementation.
* When user specified a name,
* the user specified name will be used.
* When user did not, we will automatically generate a name based on hint string.
*
* When user did not, we will automatically generate a
* name based on hint string.
* @param name : str or None
The name user specified.
* @param hint : str
A hint string, which can be used to generate name.
* @param name : The name user specified.
* @param hint : A hint string, which can be used to generate name.
* @return A canonical name for the user.
*/
def get(name: String, hint: String): String = {
Expand All @@ -30,9 +27,9 @@ class NameManager {
if (!counter.contains(hint)) {
counter(hint) = 0
}
val name = s"$hint${counter(hint)}"
val generatedName = s"$hint${counter(hint)}"
counter(hint) += 1
name
generatedName
}
}

Expand Down
170 changes: 107 additions & 63 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,83 @@ import org.slf4j.LoggerFactory

import scala.collection.mutable.ListBuffer

/**
* Symbolic configuration API of mxnet.
* @author Yizhi Liu
*/
class Symbol(private[mxnet] val handle: SymbolHandle) {
def +(other: Symbol): Symbol = Symbol.creator("_Plus", other)
def +(other: Int): Symbol = ???
def +(other: Float): Symbol = ???
def +(other: Double): Symbol = ???

/**
* List all the arguments in the symbol.
* @return Array of all the arguments.
*/
def listArguments(): Array[String] = ???

/**
* List all auxiliary states in the symbol.
* @return The names of the auxiliary states.
* Notes
* -----
* Auxiliary states are special states of symbols that do not corresponds to an argument,
* and do not have gradient. But still be useful for the specific operations.
* A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm.
* Most operators do not have Auxiliary states.
*/
def listAuxiliaryStates(): Array[String] = ???

/**
* Get attribute string from the symbol, this function only works for non-grouped symbol.
* @param key The key to get attribute from.
* @return value The attribute value of the key, returns None if attribute do not exist.
*/
def attr(key: String): Option[String] = {
val ret = new RefString
val success = new RefInt
checkCall(_LIB.mxSymbolGetAttr(handle, key, ret, success))
if (success.value != 0) {
Option(ret.value)
} else {
None
}
}

// Set the attribute of the symbol.
private def setAttr(attr: Map[String, String]): Unit = {
attr.foreach { case (key, value) =>
checkCall(_LIB.mxSymbolSetAttr(handle, key, value))
}
}

/**
* Compose symbol on inputs.
* This call mutates the current symbol.
* @param symbols provide positional arguments
* @return the resulting symbol
*/
private def compose(name: String, symbols: Array[Symbol]): Unit = {
val args = symbols.map(_.handle)
checkCall(_LIB.mxSymbolCompose(handle, name, null, args))
}

private def compose(name: String, symbols: Map[String, Symbol]): Unit = {
val keys = symbols.keys.toArray
val args = symbols.values.map(_.handle).toArray
checkCall(_LIB.mxSymbolCompose(handle, name, null, args))
}
}

object Symbol {
private val logger = LoggerFactory.getLogger(classOf[Symbol])
private val functions: Map[String, SymbolFunction] = initSymbolModule()

/**
* Create a symbolic variable with specified name.
* @param name Name of the variable.
* @param attr dict of string -> string
Additional attributes to set on the variable.
* @param attr Additional attributes to set on the variable.
* @return The created variable symbol.
*/
def Variable(name: String, attr: Map[String, String] = null): Symbol = {
Expand Down Expand Up @@ -56,16 +124,17 @@ object Symbol {
* // TODO
* @return the resulting symbol
*/
def creator(operator: String,
name: String,
attr: Map[String, String],
paramKwargs: Map[String, String],
symbols: Symbol*): Symbol = {
private def creator(operator: String,
name: String,
attr: Map[String, String],
paramKwargs: Map[String, String],
symbols: Symbol*): Symbol = {
val function = functions(operator)
require(function != null, s"invalid operator name $operator")

val addkeyVarNumArgs =
(function.keyVarNumArgs != null) && !paramKwargs.contains(function.keyVarNumArgs)
val addkeyVarNumArgs = (function.keyVarNumArgs != null
&& !function.keyVarNumArgs.isEmpty
&& !paramKwargs.contains(function.keyVarNumArgs))

val paramKeys: Array[String] = (
if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs)
Expand All @@ -86,69 +155,44 @@ object Symbol {
s.setAttr(attrAll)
val hint = operator.toLowerCase
val managedName = NameManager.current.get(name, hint)
s.compose(name = managedName, symbols.toArray)
s.compose(managedName, symbols.toArray)
s
}
}

class Symbol(private[mxnet] val handle: SymbolHandle) {
def +(other: Symbol): Symbol = ???
def +(other: Int): Symbol = ???
def +(other: Float): Symbol = ???
def +(other: Double): Symbol = ???

/**
* List all the arguments in the symbol.
* @return Array of all the arguments.
*/
def listArguments(): Array[String] = ???
private def creator(operator: String, symbols: Symbol*): Symbol = {
creator(operator, null, null, Map.empty[String, String], symbols:_*)
}

/**
* List all auxiliary states in the symbol.
* @return The names of the auxiliary states.
* Notes
* -----
* Auxiliary states are special states of symbols that do not corresponds to an argument,
* and do not have gradient. But still be useful for the specific operations.
* A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm.
* Most operators do not have Auxiliary states.
*/
def listAuxiliaryStates(): Array[String] = ???
private def creator(operator: String,
name: String,
attr: Map[String, String],
paramKwargs: Map[String, String],
symbols: Map[String, Symbol]): Symbol = {
val function = functions(operator)
require(function != null, s"invalid operator name $operator")
require(function.keyVarNumArgs == null || function.keyVarNumArgs.isEmpty,
"This function support variable length of Symbol arguments.\n" +
"Please pass all the input Symbols via positional arguments instead of keyword arguments.")

/**
* Get attribute string from the symbol, this function only works for non-grouped symbol.
* @param key The key to get attribute from.
* @return value The attribute value of the key, returns None if attribute do not exist.
*/
def attr(key: String): Option[String] = {
val ret = new RefString
val success = new RefInt
checkCall(_LIB.mxSymbolGetAttr(handle, key, ret, success))
if (success.value != 0) {
Option(ret.value)
} else {
None
}
}
val paramKeys = paramKwargs.keys.toArray
val paramVals = paramKwargs.values.toArray
val symHandle = new SymbolHandleRef
checkCall(_LIB.mxSymbolCreateAtomicSymbol(
function.handle, paramKeys, paramVals, symHandle))

// Set the attribute of the symbol.
private def setAttr(attr: Map[String, String]): Unit = {
attr.foreach { case (key, value) =>
checkCall(_LIB.mxSymbolSetAttr(handle, key, value))
}
val s = new Symbol(symHandle.value)
val attrAll = AttrScope.current.get(attr)
s.setAttr(attrAll)
val hint = operator.toLowerCase
val managedName = NameManager.current.get(name, hint)
s.compose(managedName, symbols)
s
}

/**
* Compose symbol on inputs.
* This call mutates the current symbol.
* @param symbols provide positional arguments
* @return the resulting symbol
*/
private def compose(name: String, symbols: Array[Symbol]): Unit = {
val args = symbols.map(_.handle)
checkCall(_LIB.mxSymbolCompose(handle, name, null, args))
private def creator(operator: String, symbols: Map[String, Symbol]): Symbol = {
creator(operator, null, null, Map.empty[String, String], symbols)
}

}

case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String)
private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String)
12 changes: 12 additions & 0 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package ml.dmlc.mxnet

import org.scalatest.{BeforeAndAfterAll, FunSuite}

class SymbolSuite extends FunSuite with BeforeAndAfterAll {
test("plus") {
val sym1 = Symbol.Variable("data1")
val sym2 = Symbol.Variable("data2")
val symPlus = sym1 + sym2
// TODO: check result
}
}
14 changes: 12 additions & 2 deletions scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -780,15 +780,25 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCompose
int argSize = env->GetArrayLength(jargs);
const char **keys = NULL;
if (jkeys != NULL) {
// TODO
keys = new const char*[argSize];
for (int i = 0; i < argSize; i++) {
jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
const char *key = env->GetStringUTFChars(jkey, 0);
keys[i] = key;
}
}
jlong *args = env->GetLongArrayElements(jargs, NULL);
const char *name = env->GetStringUTFChars(jname, 0);
int ret = MXSymbolCompose((SymbolHandle) symbolPtr,
name, (mx_uint) argSize, keys,
(SymbolHandle*) args);
// release allocated memory
if (jkeys != NULL) {
// TODO
for (int i = 0; i < argSize; i++) {
jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i);
env->ReleaseStringUTFChars(jkey, keys[i]);
}
delete[] keys;
}
env->ReleaseStringUTFChars(jname, name);
env->ReleaseLongArrayElements(jargs, args, 0);
Expand Down

0 comments on commit a41016e

Please sign in to comment.