From a41016e749441b21a2e7c79e70203c9251ec6952 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 3 Jan 2016 00:01:12 +0800 Subject: [PATCH] Symbol plus tested, but not assert final result --- .../main/scala/ml/dmlc/mxnet/AttrScope.scala | 20 +-- .../scala/ml/dmlc/mxnet/NameManager.scala | 19 +- .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 170 +++++++++++------- .../scala/ml/dmlc/mxnet/SymbolSuite.scala | 12 ++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 14 +- 5 files changed, 149 insertions(+), 86 deletions(-) create mode 100644 scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala index 5bead65c379b..03c6bba384c9 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala @@ -1,15 +1,5 @@ package ml.dmlc.mxnet -object AttrScope { - private var _current = new AttrScope() - def current: AttrScope = _current - private def setCurrentAttr(attr: AttrScope): Unit = { - _current = attr - } - - def apply(attr: Map[String, String] = Map.empty): AttrScope = new AttrScope(attr) -} - /** * Attribute manager for scoping. * User can also inherit this object to change naming behavior. @@ -41,3 +31,13 @@ class AttrScope(attr: Map[String, String] = Map.empty) { } } } + +object AttrScope { + private var _current = new AttrScope() + def current: AttrScope = _current + private def setCurrentAttr(attr: AttrScope): Unit = { + _current = attr + } + + def apply(attr: Map[String, String] = Map.empty): AttrScope = new AttrScope(attr) +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala index 3fff60b80a13..3cb536ab264f 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala @@ -11,16 +11,13 @@ class NameManager { val counter: mutable.Map[String, Int] = mutable.HashMap.empty[String, Int] /** * Get the canonical name for a symbol. - * This is default implementation. - * When user specified a name, - * the user specified name will be used. + * This is default implementation. + * When user specified a name, + * the user specified name will be used. + * When user did not, we will automatically generate a name based on hint string. * - * When user did not, we will automatically generate a - * name based on hint string. - * @param name : str or None - The name user specified. - * @param hint : str - A hint string, which can be used to generate name. + * @param name : The name user specified. + * @param hint : A hint string, which can be used to generate name. * @return A canonical name for the user. */ def get(name: String, hint: String): String = { @@ -30,9 +27,9 @@ class NameManager { if (!counter.contains(hint)) { counter(hint) = 0 } - val name = s"$hint${counter(hint)}" + val generatedName = s"$hint${counter(hint)}" counter(hint) += 1 - name + generatedName } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 4bbe3e1c8049..7a8699ee8a25 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -5,6 +5,75 @@ import org.slf4j.LoggerFactory import scala.collection.mutable.ListBuffer +/** + * Symbolic configuration API of mxnet. + * @author Yizhi Liu + */ +class Symbol(private[mxnet] val handle: SymbolHandle) { + def +(other: Symbol): Symbol = Symbol.creator("_Plus", other) + def +(other: Int): Symbol = ??? + def +(other: Float): Symbol = ??? + def +(other: Double): Symbol = ??? + + /** + * List all the arguments in the symbol. + * @return Array of all the arguments. + */ + def listArguments(): Array[String] = ??? + + /** + * List all auxiliary states in the symbol. + * @return The names of the auxiliary states. + * Notes + * ----- + * Auxiliary states are special states of symbols that do not corresponds to an argument, + * and do not have gradient. But still be useful for the specific operations. + * A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. + * Most operators do not have Auxiliary states. + */ + def listAuxiliaryStates(): Array[String] = ??? + + /** + * Get attribute string from the symbol, this function only works for non-grouped symbol. + * @param key The key to get attribute from. + * @return value The attribute value of the key, returns None if attribute do not exist. + */ + def attr(key: String): Option[String] = { + val ret = new RefString + val success = new RefInt + checkCall(_LIB.mxSymbolGetAttr(handle, key, ret, success)) + if (success.value != 0) { + Option(ret.value) + } else { + None + } + } + + // Set the attribute of the symbol. + private def setAttr(attr: Map[String, String]): Unit = { + attr.foreach { case (key, value) => + checkCall(_LIB.mxSymbolSetAttr(handle, key, value)) + } + } + + /** + * Compose symbol on inputs. + * This call mutates the current symbol. + * @param symbols provide positional arguments + * @return the resulting symbol + */ + private def compose(name: String, symbols: Array[Symbol]): Unit = { + val args = symbols.map(_.handle) + checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + } + + private def compose(name: String, symbols: Map[String, Symbol]): Unit = { + val keys = symbols.keys.toArray + val args = symbols.values.map(_.handle).toArray + checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + } +} + object Symbol { private val logger = LoggerFactory.getLogger(classOf[Symbol]) private val functions: Map[String, SymbolFunction] = initSymbolModule() @@ -12,8 +81,7 @@ object Symbol { /** * Create a symbolic variable with specified name. * @param name Name of the variable. - * @param attr dict of string -> string - Additional attributes to set on the variable. + * @param attr Additional attributes to set on the variable. * @return The created variable symbol. */ def Variable(name: String, attr: Map[String, String] = null): Symbol = { @@ -56,16 +124,17 @@ object Symbol { * // TODO * @return the resulting symbol */ - def creator(operator: String, - name: String, - attr: Map[String, String], - paramKwargs: Map[String, String], - symbols: Symbol*): Symbol = { + private def creator(operator: String, + name: String, + attr: Map[String, String], + paramKwargs: Map[String, String], + symbols: Symbol*): Symbol = { val function = functions(operator) require(function != null, s"invalid operator name $operator") - val addkeyVarNumArgs = - (function.keyVarNumArgs != null) && !paramKwargs.contains(function.keyVarNumArgs) + val addkeyVarNumArgs = (function.keyVarNumArgs != null + && !function.keyVarNumArgs.isEmpty + && !paramKwargs.contains(function.keyVarNumArgs)) val paramKeys: Array[String] = ( if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs) @@ -86,69 +155,44 @@ object Symbol { s.setAttr(attrAll) val hint = operator.toLowerCase val managedName = NameManager.current.get(name, hint) - s.compose(name = managedName, symbols.toArray) + s.compose(managedName, symbols.toArray) s } -} -class Symbol(private[mxnet] val handle: SymbolHandle) { - def +(other: Symbol): Symbol = ??? - def +(other: Int): Symbol = ??? - def +(other: Float): Symbol = ??? - def +(other: Double): Symbol = ??? - - /** - * List all the arguments in the symbol. - * @return Array of all the arguments. - */ - def listArguments(): Array[String] = ??? + private def creator(operator: String, symbols: Symbol*): Symbol = { + creator(operator, null, null, Map.empty[String, String], symbols:_*) + } - /** - * List all auxiliary states in the symbol. - * @return The names of the auxiliary states. - * Notes - * ----- - * Auxiliary states are special states of symbols that do not corresponds to an argument, - * and do not have gradient. But still be useful for the specific operations. - * A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. - * Most operators do not have Auxiliary states. - */ - def listAuxiliaryStates(): Array[String] = ??? + private def creator(operator: String, + name: String, + attr: Map[String, String], + paramKwargs: Map[String, String], + symbols: Map[String, Symbol]): Symbol = { + val function = functions(operator) + require(function != null, s"invalid operator name $operator") + require(function.keyVarNumArgs == null || function.keyVarNumArgs.isEmpty, + "This function support variable length of Symbol arguments.\n" + + "Please pass all the input Symbols via positional arguments instead of keyword arguments.") - /** - * Get attribute string from the symbol, this function only works for non-grouped symbol. - * @param key The key to get attribute from. - * @return value The attribute value of the key, returns None if attribute do not exist. - */ - def attr(key: String): Option[String] = { - val ret = new RefString - val success = new RefInt - checkCall(_LIB.mxSymbolGetAttr(handle, key, ret, success)) - if (success.value != 0) { - Option(ret.value) - } else { - None - } - } + val paramKeys = paramKwargs.keys.toArray + val paramVals = paramKwargs.values.toArray + val symHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateAtomicSymbol( + function.handle, paramKeys, paramVals, symHandle)) - // Set the attribute of the symbol. - private def setAttr(attr: Map[String, String]): Unit = { - attr.foreach { case (key, value) => - checkCall(_LIB.mxSymbolSetAttr(handle, key, value)) - } + val s = new Symbol(symHandle.value) + val attrAll = AttrScope.current.get(attr) + s.setAttr(attrAll) + val hint = operator.toLowerCase + val managedName = NameManager.current.get(name, hint) + s.compose(managedName, symbols) + s } - /** - * Compose symbol on inputs. - * This call mutates the current symbol. - * @param symbols provide positional arguments - * @return the resulting symbol - */ - private def compose(name: String, symbols: Array[Symbol]): Unit = { - val args = symbols.map(_.handle) - checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + private def creator(operator: String, symbols: Map[String, Symbol]): Symbol = { + creator(operator, null, null, Map.empty[String, String], symbols) } } -case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) +private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala new file mode 100644 index 000000000000..6a3b02a23d60 --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala @@ -0,0 +1,12 @@ +package ml.dmlc.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class SymbolSuite extends FunSuite with BeforeAndAfterAll { + test("plus") { + val sym1 = Symbol.Variable("data1") + val sym2 = Symbol.Variable("data2") + val symPlus = sym1 + sym2 + // TODO: check result + } +} diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index fb2758e3cb6e..508f26f8378c 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -780,15 +780,25 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCompose int argSize = env->GetArrayLength(jargs); const char **keys = NULL; if (jkeys != NULL) { - // TODO + keys = new const char*[argSize]; + for (int i = 0; i < argSize; i++) { + jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); + const char *key = env->GetStringUTFChars(jkey, 0); + keys[i] = key; + } } jlong *args = env->GetLongArrayElements(jargs, NULL); const char *name = env->GetStringUTFChars(jname, 0); int ret = MXSymbolCompose((SymbolHandle) symbolPtr, name, (mx_uint) argSize, keys, (SymbolHandle*) args); + // release allocated memory if (jkeys != NULL) { - // TODO + for (int i = 0; i < argSize; i++) { + jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); + env->ReleaseStringUTFChars(jkey, keys[i]); + } + delete[] keys; } env->ReleaseStringUTFChars(jname, name); env->ReleaseLongArrayElements(jargs, args, 0);