From 930cd91b0d1320a1a1ddc1c5a7f699e5b77e7cc9 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 31 Dec 2015 23:37:38 +0800 Subject: [PATCH 1/6] Symbol functions list, creator not implemented yet --- .../src/main/scala/ml/dmlc/mxnet/Base.scala | 4 +- .../src/main/scala/ml/dmlc/mxnet/IO.scala | 2 +- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 11 ++++ .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 44 ++++++++++++- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 63 +++++++++++++++++++ 5 files changed, 120 insertions(+), 4 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 0d5c848afd79..feff8441acb1 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -11,6 +11,8 @@ object Base { type MXFloat = Float type CPtrAddress = Long + type SymbolHandle = CPtrAddress + type MXUintRef = RefInt type MXFloatRef = RefFloat type NDArrayHandle = RefLong @@ -20,11 +22,9 @@ object Base { type KVStoreHandle = RefLong type ExecutorHandle = RefLong - System.loadLibrary("mxnet-scala") val _LIB = new LibInfo - // helper function definitions /** * Check the return value of C API call diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 9b999e0d6b32..2a706703aecd 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -14,7 +14,7 @@ object IO { /** * create iterator via iterName and params * @param iterName name of iterator; "MNISTIter" or "ImageRecordIter" - * @param params paramters for create iterator + * @param params parameters for create iterator * @return */ def createIterator(iterName: String, params: Map[String, String]): DataIter = { diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 79e84b598c15..decb4c238820 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -120,4 +120,15 @@ class LibInfo { grads: Array[CPtrAddress]): Int @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int @native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int + + // Symbols + @native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int + @native def mxSymbolGetAtomicSymbolInfo(handle: SymbolHandle, + name: RefString, + desc: RefString, + numArgs: MXUintRef, + argNames: ListBuffer[String], + argTypes: ListBuffer[String], + argDescs: ListBuffer[String], + keyVarNumArgs: RefString): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 3f70f2764c2a..18ebb6d114bb 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -1,6 +1,46 @@ package ml.dmlc.mxnet -class Symbol { +import ml.dmlc.mxnet.Base._ +import org.slf4j.LoggerFactory + +import scala.collection.mutable.ListBuffer + +object Symbol { + private val logger = LoggerFactory.getLogger(classOf[Symbol]) + private val functions: Map[String, SymbolFunction] = initSymbolModule() + + // List and add all the atomic symbol functions to current module. + private def initSymbolModule(): Map[String, SymbolFunction] = { + val symbolList = ListBuffer.empty[SymbolHandle] + checkCall(_LIB.mxSymbolListAtomicSymbolCreators(symbolList)) + symbolList.map(makeAtomicSymbolFunction).toMap + } + + // Create an atomic symbol function by handle and funciton name. + private def makeAtomicSymbolFunction(handle: SymbolHandle): (String, SymbolFunction) = { + val name = new RefString + val desc = new RefString + val keyVarNumArgs = new RefString + val numArgs = new MXUintRef + val argNames = ListBuffer.empty[String] + val argTypes = ListBuffer.empty[String] + val argDescs = ListBuffer.empty[String] + + checkCall(_LIB.mxSymbolGetAtomicSymbolInfo( + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)) + val paramStr = ctypes2docstring(argNames, argTypes, argDescs) + val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n" + logger.debug("Atomic Symbol function defination:\n{}", docStr) + (name.value, new SymbolFunction(handle, keyVarNumArgs.value)) + } +} + +class Symbol(private[mxnet] val handle: SymbolHandle) { + def +(other: Symbol): Symbol = ??? + def +(other: Int): Symbol = ??? + def +(other: Float): Symbol = ??? + def +(other: Double): Symbol = ??? + /** * List all the arguments in the symbol. * @return Array of all the arguments. @@ -19,3 +59,5 @@ class Symbol { */ def listAuxiliaryStates(): Array[String] = ??? } + +case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index ade79168a9a7..cb12b3c0f5f0 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -666,3 +666,66 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetPadNum setIntField(env, pad, cpad); return ret; } + +// Symbol functions +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListAtomicSymbolCreators + (JNIEnv *env, jobject obj, jobject symbolList) { + mx_uint outSize; + AtomicSymbolCreator *outArray; + int ret = MXSymbolListAtomicSymbolCreators(&outSize, &outArray); + + jclass longCls = env->FindClass("java/lang/Long"); + jmethodID longConst = env->GetMethodID(longCls, "", "(J)V"); + + jclass listCls = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listCls, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + + for (int i = 0; i < outSize; ++i) { + env->CallObjectMethod(symbolList, listAppend, + env->NewObject(longCls, longConst, outArray[i])); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject name, jobject desc, jobject numArgs, + jobject argNames, jobject argTypes, jobject argDescs, jobject keyVarNumArgs) { + + const char *cName; + const char *cDesc; + mx_uint cNumArgs; + const char **cArgNames; + const char **cArgTypes; + const char **cArgDescs; + const char *cKeyVarNumArgs; + + int ret = MXSymbolGetAtomicSymbolInfo((AtomicSymbolCreator) symbolPtr, + &cName, &cDesc, &cNumArgs, + &cArgNames, &cArgTypes, &cArgDescs, + &cKeyVarNumArgs); + + jclass refIntClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); + jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I"); + + jclass refStringClass = env->FindClass("ml/dmlc/mxnet/Base$RefString"); + jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;"); + + // scala.collection.mutable.ListBuffer append method + jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq", + "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + + env->SetObjectField(name, valueStr, env->NewStringUTF(cName)); + env->SetObjectField(desc, valueStr, env->NewStringUTF(cDesc)); + env->SetObjectField(keyVarNumArgs, valueStr, env->NewStringUTF(cKeyVarNumArgs)); + env->SetIntField(numArgs, valueInt, (jint)cNumArgs); + for (int i = 0; i < cNumArgs; ++i) { + env->CallObjectMethod(argNames, listAppend, env->NewStringUTF(cArgNames[i])); + env->CallObjectMethod(argTypes, listAppend, env->NewStringUTF(cArgTypes[i])); + env->CallObjectMethod(argDescs, listAppend, env->NewStringUTF(cArgDescs[i])); + } + + return ret; +} From 88ab916a32bccdad034df125bbe16c1307070a80 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 1 Jan 2016 23:01:11 +0800 Subject: [PATCH 2/6] Symbol function build with attr scope --- .../main/scala/ml/dmlc/mxnet/AttrScope.scala | 34 ++++++++++++ .../src/main/scala/ml/dmlc/mxnet/Base.scala | 1 + .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 5 ++ .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 54 ++++++++++++++++++- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 44 +++++++++++++++ 5 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala new file mode 100644 index 000000000000..689cb71b10a7 --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala @@ -0,0 +1,34 @@ +package ml.dmlc.mxnet + +object AttrScope { + private var _current = new AttrScope() + def current: AttrScope = _current + private def setCurrentAttr(attr: AttrScope): Unit = { + _current = attr + } + + def withScope[T](attr: Map[String, String])(body: => T): T = { + val oldAttrScope = AttrScope.current + val updatedAttr = AttrScope.current.attr ++ attr + AttrScope.setCurrentAttr(new AttrScope(updatedAttr)) + val ret = body + AttrScope.setCurrentAttr(oldAttrScope) + ret + } +} + +/** + * Attribute manager for scoping. + * User can also inherit this object to change naming behavior. + * @author Yizhi Liu + */ +class AttrScope(private var attr: Map[String, String] = Map.empty) { + /** + * Get the attribute dict given the attribute set by the symbol. + * @param userDefinedAttr The attribute passed in by user during symbol creation. + * @return Updated attributes to add other scope related attributes. + */ + def get(userDefinedAttr: Map[String, String]): Map[String, String] = { + attr ++ userDefinedAttr + } +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index feff8441acb1..26026fc4d2eb 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -21,6 +21,7 @@ object Base { type DataIterCreator = RefLong type KVStoreHandle = RefLong type ExecutorHandle = RefLong + type SymbolHandleRef = RefLong System.loadLibrary("mxnet-scala") val _LIB = new LibInfo diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index decb4c238820..653520480efd 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -131,4 +131,9 @@ class LibInfo { argTypes: ListBuffer[String], argDescs: ListBuffer[String], keyVarNumArgs: RefString): Int + @native def mxSymbolCreateAtomicSymbol(handle: SymbolHandle, + paramKeys: Array[String], + paramVals: Array[String], + symHandleRef: SymbolHandleRef): Int + @native def mxSymbolSetAttr(handle: SymbolHandle, key: String, value: String): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 18ebb6d114bb..6a160e9d184e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -16,7 +16,7 @@ object Symbol { symbolList.map(makeAtomicSymbolFunction).toMap } - // Create an atomic symbol function by handle and funciton name. + // Create an atomic symbol function by handle and function name. private def makeAtomicSymbolFunction(handle: SymbolHandle): (String, SymbolFunction) = { val name = new RefString val desc = new RefString @@ -33,6 +33,49 @@ object Symbol { logger.debug("Atomic Symbol function defination:\n{}", docStr) (name.value, new SymbolFunction(handle, keyVarNumArgs.value)) } + + /** + * Activation Operator of Neural Net. + * The parameters listed below can be passed in as keyword arguments. + * @param name Name of the resulting symbol. + * // TODO + * @return the resulting symbol + */ + def creator(operator: String, + name: String, + attr: Map[String, String], + paramKwargs: Map[String, String], + symbols: Symbol*): Symbol = { + val function = functions(operator) + require(function != null, s"invalid operator name $operator") + + val addkeyVarNumArgs = + (function.keyVarNumArgs != null) && !paramKwargs.contains(function.keyVarNumArgs) + + val paramKeys: Array[String] = ( + if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs) + else Array.empty[String] + ) ++ paramKwargs.keys + val paramVals: Array[String] = ( + if (addkeyVarNumArgs) Array[String](symbols.length.toString) + else Array.empty[String] + ) ++ paramKwargs.values + + // create atomic symbol + val symHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateAtomicSymbol( + function.handle, paramKeys, paramVals, symHandle)) + + val s = new Symbol(symHandle.value) + val attrAll = AttrScope.current.get(attr) + s.setAttr(attrAll) + val hint = operator.toLowerCase + /* TODO + name = NameManager.current.get(name, hint) + s._compose(*args, name = name, **symbol_kwargs) + */ + s + } } class Symbol(private[mxnet] val handle: SymbolHandle) { @@ -48,7 +91,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { def listArguments(): Array[String] = ??? /** - * List all auxiliary states in the symbool. + * List all auxiliary states in the symbol. * @return The names of the auxiliary states. * Notes * ----- @@ -58,6 +101,13 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { * Most operators do not have Auxiliary states. */ def listAuxiliaryStates(): Array[String] = ??? + + // Set the attribute of the symbol. + private def setAttr(attr: Map[String, String]): Unit = { + attr.foreach { case (key, value) => + checkCall(_LIB.mxSymbolSetAttr(handle, key, value)) + } + } } case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index cb12b3c0f5f0..e3c71cb1c82a 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -729,3 +729,47 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo return ret; } + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateAtomicSymbol + (JNIEnv *env, jobject obj, jlong symbolPtr, jobjectArray paramKeys, + jobjectArray paramVals, jobject symbolRef) { + int paramSize = env->GetArrayLength(paramKeys); + const char **keys = new const char*[paramSize]; + const char **vals = new const char*[paramSize]; + for (int i = 0; i < paramSize; i++) { + jstring key = (jstring) env->GetObjectArrayElement(paramKeys, i); + const char *rawKey = env->GetStringUTFChars(key, 0); + keys[i] = rawKey; + + jstring value = (jstring) env->GetObjectArrayElement(paramVals, i); + const char *rawValue = env->GetStringUTFChars(value, 0); + vals[i] = rawValue; + } + + SymbolHandle out; + int ret = MXSymbolCreateAtomicSymbol( + (AtomicSymbolCreator) symbolPtr, (mx_uint) paramSize, keys, vals, &out); + setLongField(env, symbolRef, (jlong) out); + + // release keys and vals + for (int i = 0; i < paramSize; i++) { + jstring key = (jstring) env->GetObjectArrayElement(paramKeys, i); + env->ReleaseStringUTFChars(key, keys[i]); + jstring value = (jstring) env->GetObjectArrayElement(paramVals, i); + env->ReleaseStringUTFChars(value, vals[i]); + } + delete[] keys; + delete[] vals; + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolSetAttr + (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jstring jvalue) { + const char *ckey = env->GetStringUTFChars(jkey, 0); + const char *cvalue = env->GetStringUTFChars(jvalue, 0); + int ret = MXSymbolSetAttr((SymbolHandle) symbolPtr, ckey, cvalue); + env->ReleaseStringUTFChars(jkey, ckey); + env->ReleaseStringUTFChars(jvalue, cvalue); + return ret; +} From ec162771e7f6e2033432c6d6829aba76adf5dbdd Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 1 Jan 2016 23:50:15 +0800 Subject: [PATCH 3/6] NameManager. AttrScope withScope move from object to class --- .../main/scala/ml/dmlc/mxnet/AttrScope.scala | 23 ++++---- .../scala/ml/dmlc/mxnet/NameManager.scala | 56 +++++++++++++++++++ .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 2 +- 3 files changed, 70 insertions(+), 11 deletions(-) create mode 100644 scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala index 689cb71b10a7..8631849f6ccf 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala @@ -6,15 +6,6 @@ object AttrScope { private def setCurrentAttr(attr: AttrScope): Unit = { _current = attr } - - def withScope[T](attr: Map[String, String])(body: => T): T = { - val oldAttrScope = AttrScope.current - val updatedAttr = AttrScope.current.attr ++ attr - AttrScope.setCurrentAttr(new AttrScope(updatedAttr)) - val ret = body - AttrScope.setCurrentAttr(oldAttrScope) - ret - } } /** @@ -22,7 +13,8 @@ object AttrScope { * User can also inherit this object to change naming behavior. * @author Yizhi Liu */ -class AttrScope(private var attr: Map[String, String] = Map.empty) { +class AttrScope(val attr: Map[String, String] = Map.empty) { + private var _attr = attr /** * Get the attribute dict given the attribute set by the symbol. * @param userDefinedAttr The attribute passed in by user during symbol creation. @@ -31,4 +23,15 @@ class AttrScope(private var attr: Map[String, String] = Map.empty) { def get(userDefinedAttr: Map[String, String]): Map[String, String] = { attr ++ userDefinedAttr } + + def withScope[T](body: => T): T = { + val oldAttrScope = AttrScope.current + this._attr = AttrScope.current.attr ++ this._attr + AttrScope.setCurrentAttr(this) + try { + body + } finally { + AttrScope.setCurrentAttr(oldAttrScope) + } + } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala new file mode 100644 index 000000000000..3fff60b80a13 --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala @@ -0,0 +1,56 @@ +package ml.dmlc.mxnet + +import scala.collection.mutable + +/** + * NameManager to do automatic naming. + * User can also inherit this object to change naming behavior. + * @author Yizhi Liu + */ +class NameManager { + val counter: mutable.Map[String, Int] = mutable.HashMap.empty[String, Int] + /** + * Get the canonical name for a symbol. + * This is default implementation. + * When user specified a name, + * the user specified name will be used. + * + * When user did not, we will automatically generate a + * name based on hint string. + * @param name : str or None + The name user specified. + * @param hint : str + A hint string, which can be used to generate name. + * @return A canonical name for the user. + */ + def get(name: String, hint: String): String = { + if (name != null) { + name + } else { + if (!counter.contains(hint)) { + counter(hint) = 0 + } + val name = s"$hint${counter(hint)}" + counter(hint) += 1 + name + } + } + + def withScope[T](body: => T): T = { + val oldManager = NameManager.current + NameManager.setCurrentManager(this) + try { + body + } finally { + NameManager.setCurrentManager(oldManager) + } + } +} + +object NameManager { + private var _current = new NameManager() + def current: NameManager = _current + private def setCurrentManager(manager: NameManager): Unit = { + _current = manager + } +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 6a160e9d184e..b4ea5bb9d8cf 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -70,8 +70,8 @@ object Symbol { val attrAll = AttrScope.current.get(attr) s.setAttr(attrAll) val hint = operator.toLowerCase + val managedName = NameManager.current.get(name, hint) /* TODO - name = NameManager.current.get(name, hint) s._compose(*args, name = name, **symbol_kwargs) */ s From 12b11a788d1f4791b2a56806a3a1ea763adc3b11 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 2 Jan 2016 16:16:12 +0800 Subject: [PATCH 4/6] Symbol creator partially finished. AttrScopeSuite using Symbol.Variable --- .../main/scala/ml/dmlc/mxnet/AttrScope.scala | 12 +++-- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 9 ++++ .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 47 +++++++++++++++++-- .../scala/ml/dmlc/mxnet/AttrScopeSuite.scala | 19 ++++++++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 45 ++++++++++++++++++ 5 files changed, 126 insertions(+), 6 deletions(-) create mode 100644 scala-package/core/src/test/scala/ml/dmlc/mxnet/AttrScopeSuite.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala index 8631849f6ccf..5bead65c379b 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala @@ -6,6 +6,8 @@ object AttrScope { private def setCurrentAttr(attr: AttrScope): Unit = { _current = attr } + + def apply(attr: Map[String, String] = Map.empty): AttrScope = new AttrScope(attr) } /** @@ -13,7 +15,7 @@ object AttrScope { * User can also inherit this object to change naming behavior. * @author Yizhi Liu */ -class AttrScope(val attr: Map[String, String] = Map.empty) { +class AttrScope(attr: Map[String, String] = Map.empty) { private var _attr = attr /** * Get the attribute dict given the attribute set by the symbol. @@ -21,12 +23,16 @@ class AttrScope(val attr: Map[String, String] = Map.empty) { * @return Updated attributes to add other scope related attributes. */ def get(userDefinedAttr: Map[String, String]): Map[String, String] = { - attr ++ userDefinedAttr + if (userDefinedAttr != null) { + attr ++ userDefinedAttr + } else { + attr + } } def withScope[T](body: => T): T = { val oldAttrScope = AttrScope.current - this._attr = AttrScope.current.attr ++ this._attr + this._attr = AttrScope.current._attr ++ this._attr AttrScope.setCurrentAttr(this) try { body diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 653520480efd..62dca8cfb629 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -136,4 +136,13 @@ class LibInfo { paramVals: Array[String], symHandleRef: SymbolHandleRef): Int @native def mxSymbolSetAttr(handle: SymbolHandle, key: String, value: String): Int + @native def mxSymbolCompose(handle: SymbolHandle, + name: String, + keys: Array[String], + args: Array[SymbolHandle]): Int + @native def mxSymbolCreateVariable(name: String, out: SymbolHandleRef): Int + @native def mxSymbolGetAttr(handle: SymbolHandle, + key: String, + ret: RefString, + success: RefInt): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index b4ea5bb9d8cf..4bbe3e1c8049 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -9,6 +9,21 @@ object Symbol { private val logger = LoggerFactory.getLogger(classOf[Symbol]) private val functions: Map[String, SymbolFunction] = initSymbolModule() + /** + * Create a symbolic variable with specified name. + * @param name Name of the variable. + * @param attr dict of string -> string + Additional attributes to set on the variable. + * @return The created variable symbol. + */ + def Variable(name: String, attr: Map[String, String] = null): Symbol = { + val handle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateVariable(name, handle)) + val sym = new Symbol(handle.value) + sym.setAttr(AttrScope.current.get(attr)) + sym + } + // List and add all the atomic symbol functions to current module. private def initSymbolModule(): Map[String, SymbolFunction] = { val symbolList = ListBuffer.empty[SymbolHandle] @@ -71,9 +86,7 @@ object Symbol { s.setAttr(attrAll) val hint = operator.toLowerCase val managedName = NameManager.current.get(name, hint) - /* TODO - s._compose(*args, name = name, **symbol_kwargs) - */ + s.compose(name = managedName, symbols.toArray) s } } @@ -102,12 +115,40 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { */ def listAuxiliaryStates(): Array[String] = ??? + /** + * Get attribute string from the symbol, this function only works for non-grouped symbol. + * @param key The key to get attribute from. + * @return value The attribute value of the key, returns None if attribute do not exist. + */ + def attr(key: String): Option[String] = { + val ret = new RefString + val success = new RefInt + checkCall(_LIB.mxSymbolGetAttr(handle, key, ret, success)) + if (success.value != 0) { + Option(ret.value) + } else { + None + } + } + // Set the attribute of the symbol. private def setAttr(attr: Map[String, String]): Unit = { attr.foreach { case (key, value) => checkCall(_LIB.mxSymbolSetAttr(handle, key, value)) } } + + /** + * Compose symbol on inputs. + * This call mutates the current symbol. + * @param symbols provide positional arguments + * @return the resulting symbol + */ + private def compose(name: String, symbols: Array[Symbol]): Unit = { + val args = symbols.map(_.handle) + checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + } + } case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/AttrScopeSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/AttrScopeSuite.scala new file mode 100644 index 000000000000..3e320ff29681 --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/AttrScopeSuite.scala @@ -0,0 +1,19 @@ +package ml.dmlc.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class AttrScopeSuite extends FunSuite with BeforeAndAfterAll { + test("attr basic") { + val (data, gdata) = + AttrScope(Map("group" -> "4", "data" -> "great")).withScope { + val data = Symbol.Variable("data", attr = Map("dtype" -> "data", "group" -> "1")) + val gdata = Symbol.Variable("data2") + (data, gdata) + } + assert(gdata.attr("group").get === "4") + assert(data.attr("group").get === "1") + + val exceedScopeData = Symbol.Variable("data3") + assert(exceedScopeData.attr("group") === None, "No group attr in global attr scope") + } +} diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index e3c71cb1c82a..fb2758e3cb6e 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -773,3 +773,48 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolSetAttr env->ReleaseStringUTFChars(jvalue, cvalue); return ret; } + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCompose + (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jname, + jobjectArray jkeys, jlongArray jargs) { + int argSize = env->GetArrayLength(jargs); + const char **keys = NULL; + if (jkeys != NULL) { + // TODO + } + jlong *args = env->GetLongArrayElements(jargs, NULL); + const char *name = env->GetStringUTFChars(jname, 0); + int ret = MXSymbolCompose((SymbolHandle) symbolPtr, + name, (mx_uint) argSize, keys, + (SymbolHandle*) args); + if (jkeys != NULL) { + // TODO + } + env->ReleaseStringUTFChars(jname, name); + env->ReleaseLongArrayElements(jargs, args, 0); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateVariable + (JNIEnv *env, jobject obj, jstring jname, jobject handle) { + SymbolHandle out; + const char *name = env->GetStringUTFChars(jname, 0); + int ret = MXSymbolCreateVariable(name, &out); + env->ReleaseStringUTFChars(jname, name); + setLongField(env, handle, (long)out); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetAttr + (JNIEnv *env, jobject obj, jlong symbolPtr, jstring jkey, jobject retRef, jobject successRef) { + + const char *out; + int success; + const char *key = env->GetStringUTFChars(jkey, 0); + int ret = MXSymbolGetAttr((SymbolHandle) symbolPtr, key, &out, &success); + env->ReleaseStringUTFChars(jkey, key); + + setStringField(env, retRef, out); + setIntField(env, successRef, success); + return ret; +} From a41016e749441b21a2e7c79e70203c9251ec6952 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 3 Jan 2016 00:01:12 +0800 Subject: [PATCH 5/6] Symbol plus tested, but not assert final result --- .../main/scala/ml/dmlc/mxnet/AttrScope.scala | 20 +-- .../scala/ml/dmlc/mxnet/NameManager.scala | 19 +- .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 170 +++++++++++------- .../scala/ml/dmlc/mxnet/SymbolSuite.scala | 12 ++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 14 +- 5 files changed, 149 insertions(+), 86 deletions(-) create mode 100644 scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala index 5bead65c379b..03c6bba384c9 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala @@ -1,15 +1,5 @@ package ml.dmlc.mxnet -object AttrScope { - private var _current = new AttrScope() - def current: AttrScope = _current - private def setCurrentAttr(attr: AttrScope): Unit = { - _current = attr - } - - def apply(attr: Map[String, String] = Map.empty): AttrScope = new AttrScope(attr) -} - /** * Attribute manager for scoping. * User can also inherit this object to change naming behavior. @@ -41,3 +31,13 @@ class AttrScope(attr: Map[String, String] = Map.empty) { } } } + +object AttrScope { + private var _current = new AttrScope() + def current: AttrScope = _current + private def setCurrentAttr(attr: AttrScope): Unit = { + _current = attr + } + + def apply(attr: Map[String, String] = Map.empty): AttrScope = new AttrScope(attr) +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala index 3fff60b80a13..3cb536ab264f 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala @@ -11,16 +11,13 @@ class NameManager { val counter: mutable.Map[String, Int] = mutable.HashMap.empty[String, Int] /** * Get the canonical name for a symbol. - * This is default implementation. - * When user specified a name, - * the user specified name will be used. + * This is default implementation. + * When user specified a name, + * the user specified name will be used. + * When user did not, we will automatically generate a name based on hint string. * - * When user did not, we will automatically generate a - * name based on hint string. - * @param name : str or None - The name user specified. - * @param hint : str - A hint string, which can be used to generate name. + * @param name : The name user specified. + * @param hint : A hint string, which can be used to generate name. * @return A canonical name for the user. */ def get(name: String, hint: String): String = { @@ -30,9 +27,9 @@ class NameManager { if (!counter.contains(hint)) { counter(hint) = 0 } - val name = s"$hint${counter(hint)}" + val generatedName = s"$hint${counter(hint)}" counter(hint) += 1 - name + generatedName } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 4bbe3e1c8049..7a8699ee8a25 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -5,6 +5,75 @@ import org.slf4j.LoggerFactory import scala.collection.mutable.ListBuffer +/** + * Symbolic configuration API of mxnet. + * @author Yizhi Liu + */ +class Symbol(private[mxnet] val handle: SymbolHandle) { + def +(other: Symbol): Symbol = Symbol.creator("_Plus", other) + def +(other: Int): Symbol = ??? + def +(other: Float): Symbol = ??? + def +(other: Double): Symbol = ??? + + /** + * List all the arguments in the symbol. + * @return Array of all the arguments. + */ + def listArguments(): Array[String] = ??? + + /** + * List all auxiliary states in the symbol. + * @return The names of the auxiliary states. + * Notes + * ----- + * Auxiliary states are special states of symbols that do not corresponds to an argument, + * and do not have gradient. But still be useful for the specific operations. + * A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. + * Most operators do not have Auxiliary states. + */ + def listAuxiliaryStates(): Array[String] = ??? + + /** + * Get attribute string from the symbol, this function only works for non-grouped symbol. + * @param key The key to get attribute from. + * @return value The attribute value of the key, returns None if attribute do not exist. + */ + def attr(key: String): Option[String] = { + val ret = new RefString + val success = new RefInt + checkCall(_LIB.mxSymbolGetAttr(handle, key, ret, success)) + if (success.value != 0) { + Option(ret.value) + } else { + None + } + } + + // Set the attribute of the symbol. + private def setAttr(attr: Map[String, String]): Unit = { + attr.foreach { case (key, value) => + checkCall(_LIB.mxSymbolSetAttr(handle, key, value)) + } + } + + /** + * Compose symbol on inputs. + * This call mutates the current symbol. + * @param symbols provide positional arguments + * @return the resulting symbol + */ + private def compose(name: String, symbols: Array[Symbol]): Unit = { + val args = symbols.map(_.handle) + checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + } + + private def compose(name: String, symbols: Map[String, Symbol]): Unit = { + val keys = symbols.keys.toArray + val args = symbols.values.map(_.handle).toArray + checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + } +} + object Symbol { private val logger = LoggerFactory.getLogger(classOf[Symbol]) private val functions: Map[String, SymbolFunction] = initSymbolModule() @@ -12,8 +81,7 @@ object Symbol { /** * Create a symbolic variable with specified name. * @param name Name of the variable. - * @param attr dict of string -> string - Additional attributes to set on the variable. + * @param attr Additional attributes to set on the variable. * @return The created variable symbol. */ def Variable(name: String, attr: Map[String, String] = null): Symbol = { @@ -56,16 +124,17 @@ object Symbol { * // TODO * @return the resulting symbol */ - def creator(operator: String, - name: String, - attr: Map[String, String], - paramKwargs: Map[String, String], - symbols: Symbol*): Symbol = { + private def creator(operator: String, + name: String, + attr: Map[String, String], + paramKwargs: Map[String, String], + symbols: Symbol*): Symbol = { val function = functions(operator) require(function != null, s"invalid operator name $operator") - val addkeyVarNumArgs = - (function.keyVarNumArgs != null) && !paramKwargs.contains(function.keyVarNumArgs) + val addkeyVarNumArgs = (function.keyVarNumArgs != null + && !function.keyVarNumArgs.isEmpty + && !paramKwargs.contains(function.keyVarNumArgs)) val paramKeys: Array[String] = ( if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs) @@ -86,69 +155,44 @@ object Symbol { s.setAttr(attrAll) val hint = operator.toLowerCase val managedName = NameManager.current.get(name, hint) - s.compose(name = managedName, symbols.toArray) + s.compose(managedName, symbols.toArray) s } -} -class Symbol(private[mxnet] val handle: SymbolHandle) { - def +(other: Symbol): Symbol = ??? - def +(other: Int): Symbol = ??? - def +(other: Float): Symbol = ??? - def +(other: Double): Symbol = ??? - - /** - * List all the arguments in the symbol. - * @return Array of all the arguments. - */ - def listArguments(): Array[String] = ??? + private def creator(operator: String, symbols: Symbol*): Symbol = { + creator(operator, null, null, Map.empty[String, String], symbols:_*) + } - /** - * List all auxiliary states in the symbol. - * @return The names of the auxiliary states. - * Notes - * ----- - * Auxiliary states are special states of symbols that do not corresponds to an argument, - * and do not have gradient. But still be useful for the specific operations. - * A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. - * Most operators do not have Auxiliary states. - */ - def listAuxiliaryStates(): Array[String] = ??? + private def creator(operator: String, + name: String, + attr: Map[String, String], + paramKwargs: Map[String, String], + symbols: Map[String, Symbol]): Symbol = { + val function = functions(operator) + require(function != null, s"invalid operator name $operator") + require(function.keyVarNumArgs == null || function.keyVarNumArgs.isEmpty, + "This function support variable length of Symbol arguments.\n" + + "Please pass all the input Symbols via positional arguments instead of keyword arguments.") - /** - * Get attribute string from the symbol, this function only works for non-grouped symbol. - * @param key The key to get attribute from. - * @return value The attribute value of the key, returns None if attribute do not exist. - */ - def attr(key: String): Option[String] = { - val ret = new RefString - val success = new RefInt - checkCall(_LIB.mxSymbolGetAttr(handle, key, ret, success)) - if (success.value != 0) { - Option(ret.value) - } else { - None - } - } + val paramKeys = paramKwargs.keys.toArray + val paramVals = paramKwargs.values.toArray + val symHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateAtomicSymbol( + function.handle, paramKeys, paramVals, symHandle)) - // Set the attribute of the symbol. - private def setAttr(attr: Map[String, String]): Unit = { - attr.foreach { case (key, value) => - checkCall(_LIB.mxSymbolSetAttr(handle, key, value)) - } + val s = new Symbol(symHandle.value) + val attrAll = AttrScope.current.get(attr) + s.setAttr(attrAll) + val hint = operator.toLowerCase + val managedName = NameManager.current.get(name, hint) + s.compose(managedName, symbols) + s } - /** - * Compose symbol on inputs. - * This call mutates the current symbol. - * @param symbols provide positional arguments - * @return the resulting symbol - */ - private def compose(name: String, symbols: Array[Symbol]): Unit = { - val args = symbols.map(_.handle) - checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + private def creator(operator: String, symbols: Map[String, Symbol]): Symbol = { + creator(operator, null, null, Map.empty[String, String], symbols) } } -case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) +private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala new file mode 100644 index 000000000000..6a3b02a23d60 --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala @@ -0,0 +1,12 @@ +package ml.dmlc.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class SymbolSuite extends FunSuite with BeforeAndAfterAll { + test("plus") { + val sym1 = Symbol.Variable("data1") + val sym2 = Symbol.Variable("data2") + val symPlus = sym1 + sym2 + // TODO: check result + } +} diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index fb2758e3cb6e..508f26f8378c 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -780,15 +780,25 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCompose int argSize = env->GetArrayLength(jargs); const char **keys = NULL; if (jkeys != NULL) { - // TODO + keys = new const char*[argSize]; + for (int i = 0; i < argSize; i++) { + jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); + const char *key = env->GetStringUTFChars(jkey, 0); + keys[i] = key; + } } jlong *args = env->GetLongArrayElements(jargs, NULL); const char *name = env->GetStringUTFChars(jname, 0); int ret = MXSymbolCompose((SymbolHandle) symbolPtr, name, (mx_uint) argSize, keys, (SymbolHandle*) args); + // release allocated memory if (jkeys != NULL) { - // TODO + for (int i = 0; i < argSize; i++) { + jstring jkey = (jstring) env->GetObjectArrayElement(jkeys, i); + env->ReleaseStringUTFChars(jkey, keys[i]); + } + delete[] keys; } env->ReleaseStringUTFChars(jname, name); env->ReleaseLongArrayElements(jargs, args, 0); From e1aa3865231eb31d8b3095d5b7ef0f12ba92e1d3 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 3 Jan 2016 18:30:14 +0800 Subject: [PATCH 6/6] FullyConnected, Activate and Group Symbol creation --- .../main/scala/ml/dmlc/mxnet/AttrScope.scala | 8 +- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 7 + .../scala/ml/dmlc/mxnet/NameManager.scala | 6 +- .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 170 ++++++++++++++---- .../scala/ml/dmlc/mxnet/SymbolSuite.scala | 26 ++- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 72 +++++++- scala-package/pom.xml | 5 + 7 files changed, 239 insertions(+), 55 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala index 03c6bba384c9..2d4b678c01d7 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala @@ -12,12 +12,8 @@ class AttrScope(attr: Map[String, String] = Map.empty) { * @param userDefinedAttr The attribute passed in by user during symbol creation. * @return Updated attributes to add other scope related attributes. */ - def get(userDefinedAttr: Map[String, String]): Map[String, String] = { - if (userDefinedAttr != null) { - attr ++ userDefinedAttr - } else { - attr - } + def get(userDefinedAttr: Option[Map[String, String]]): Map[String, String] = { + _attr ++ userDefinedAttr.getOrElse(Map.empty[String, String]) } def withScope[T](body: => T): T = { diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 62dca8cfb629..0837e6d1ca80 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -145,4 +145,11 @@ class LibInfo { key: String, ret: RefString, success: RefInt): Int + @native def mxSymbolListArguments(handle: SymbolHandle, + arguments: ArrayBuffer[String]): Int + @native def mxSymbolCopy(handle: SymbolHandle, clonedHandle: SymbolHandleRef): Int + @native def mxSymbolListOutputs(handle: SymbolHandle, + outputs: ArrayBuffer[String]): Int + @native def mxSymbolCreateGroup(handles: Array[SymbolHandle], out: SymbolHandleRef): Int + @native def mxSymbolPrint(handle: SymbolHandle, str: RefString): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala index 3cb536ab264f..f81af5ed1724 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala @@ -20,10 +20,8 @@ class NameManager { * @param hint : A hint string, which can be used to generate name. * @return A canonical name for the user. */ - def get(name: String, hint: String): String = { - if (name != null) { - name - } else { + def get(name: Option[String], hint: String): String = { + name.getOrElse { if (!counter.contains(hint)) { counter(hint) = 0 } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 7a8699ee8a25..33e10ad5ea7e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -3,23 +3,40 @@ package ml.dmlc.mxnet import ml.dmlc.mxnet.Base._ import org.slf4j.LoggerFactory -import scala.collection.mutable.ListBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} /** * Symbolic configuration API of mxnet. * @author Yizhi Liu */ class Symbol(private[mxnet] val handle: SymbolHandle) { - def +(other: Symbol): Symbol = Symbol.creator("_Plus", other) - def +(other: Int): Symbol = ??? - def +(other: Float): Symbol = ??? - def +(other: Double): Symbol = ??? + def +(other: Symbol): Symbol = Symbol.create("_Plus", other) + + override def clone(): Symbol = { + val clonedHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCopy(handle, clonedHandle)) + new Symbol(clonedHandle.value) + } /** * List all the arguments in the symbol. * @return Array of all the arguments. */ - def listArguments(): Array[String] = ??? + def listArguments(): Array[String] = { + val arr = ArrayBuffer.empty[String] + checkCall(_LIB.mxSymbolListArguments(handle, arr)) + arr.toArray + } + + /** + * List all outputs in the symbol. + * @return : List of all the outputs. + */ + def listOutputs(): Array[String] = { + val arr = ArrayBuffer.empty[String] + checkCall(_LIB.mxSymbolListOutputs(handle, arr)) + arr.toArray + } /** * List all auxiliary states in the symbol. @@ -49,6 +66,28 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { } } + /** + * Invoke symbol as function on inputs. + * @param name resulting symbol name + * @param symbols provide named symbols + * @return the resulting symbol + */ + def apply(name: String, symbols: Map[String, Symbol]): Symbol = { + val s = clone() + s.compose(name, symbols) + s + } + + /** + * Get a debug string. + * @return Debug string of the symbol. + */ + def debugStr: String = { + val str = new RefString + checkCall(_LIB.mxSymbolPrint(handle, str)) + str.value + } + // Set the attribute of the symbol. private def setAttr(attr: Map[String, String]): Unit = { attr.foreach { case (key, value) => @@ -59,6 +98,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { /** * Compose symbol on inputs. * This call mutates the current symbol. + * @param name resulting symbol name * @param symbols provide positional arguments * @return the resulting symbol */ @@ -70,7 +110,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { private def compose(name: String, symbols: Map[String, Symbol]): Unit = { val keys = symbols.keys.toArray val args = symbols.values.map(_.handle).toArray - checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + checkCall(_LIB.mxSymbolCompose(handle, name, keys, args)) } } @@ -88,10 +128,38 @@ object Symbol { val handle = new SymbolHandleRef checkCall(_LIB.mxSymbolCreateVariable(name, handle)) val sym = new Symbol(handle.value) - sym.setAttr(AttrScope.current.get(attr)) + sym.setAttr(AttrScope.current.get(Option(attr))) sym } + def FullyConnected: Map[String, Any] => Symbol = { + FullyConnected(null) + } + + def FullyConnected(attr: Map[String, String]): Map[String, Any] => Symbol = { + createNoCheck("FullyConnected", attr) + } + + def Activation: Map[String, Any] => Symbol = { + Activation(null) + } + + def Activation(attr: Map[String, String]): Map[String, Any] => Symbol = { + createNoCheck("Activation", attr) + } + + /** + * Create a symbol that groups symbols together. + * @param symbols List of symbols to be grouped. + * @return The created group symbol. + */ + def Group(symbols: Symbol*): Symbol = { + val ihandles = symbols.map(_.handle).toArray + val handle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateGroup(ihandles, handle)) + new Symbol(handle.value) + } + // List and add all the atomic symbol functions to current module. private def initSymbolModule(): Map[String, SymbolFunction] = { val symbolList = ListBuffer.empty[SymbolHandle] @@ -120,30 +188,31 @@ object Symbol { /** * Activation Operator of Neural Net. * The parameters listed below can be passed in as keyword arguments. - * @param name Name of the resulting symbol. - * // TODO + * @param symbols Symbol parameters passed to create the resulting symbol + * @param paramKwargs Key-value parameters passed to create the resulting symbol + * @param attr Attributes set to the resulting symbol * @return the resulting symbol */ - private def creator(operator: String, - name: String, - attr: Map[String, String], - paramKwargs: Map[String, String], - symbols: Symbol*): Symbol = { + def create(operator: String, + symbols: Array[Symbol], + paramKwargs: Map[String, String], + attr: Map[String, String]): Symbol = { val function = functions(operator) require(function != null, s"invalid operator name $operator") + val params = if (paramKwargs == null) Map.empty[String, String] else paramKwargs val addkeyVarNumArgs = (function.keyVarNumArgs != null && !function.keyVarNumArgs.isEmpty - && !paramKwargs.contains(function.keyVarNumArgs)) + && !params.contains(function.keyVarNumArgs)) val paramKeys: Array[String] = ( if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs) else Array.empty[String] - ) ++ paramKwargs.keys + ) ++ (params - "name").keys val paramVals: Array[String] = ( if (addkeyVarNumArgs) Array[String](symbols.length.toString) else Array.empty[String] - ) ++ paramKwargs.values + ) ++ (params - "name").values // create atomic symbol val symHandle = new SymbolHandleRef @@ -151,48 +220,81 @@ object Symbol { function.handle, paramKeys, paramVals, symHandle)) val s = new Symbol(symHandle.value) - val attrAll = AttrScope.current.get(attr) + val attrAll = AttrScope.current.get(Option(attr)) s.setAttr(attrAll) val hint = operator.toLowerCase - val managedName = NameManager.current.get(name, hint) - s.compose(managedName, symbols.toArray) + val managedName = NameManager.current.get(params.get("name"), hint) + s.compose(managedName, symbols) s } - private def creator(operator: String, symbols: Symbol*): Symbol = { - creator(operator, null, null, Map.empty[String, String], symbols:_*) + def create(operator: String, symbols: Symbol*): Symbol = { + create(operator, symbols.toArray, null, null) } - private def creator(operator: String, - name: String, - attr: Map[String, String], - paramKwargs: Map[String, String], - symbols: Map[String, Symbol]): Symbol = { + /** + * Activation Operator of Neural Net. + * The parameters listed below can be passed in as keyword arguments. + * @param symbols Named symbol parameters passed to create the resulting symbol + * @param paramKwargs Key-value parameters passed to create the resulting symbol + * @param attr Attributes set to the resulting symbol + * @return the resulting symbol + */ + private def create(operator: String, + symbols: Map[String, Symbol], + paramKwargs: Map[String, String], + attr: Map[String, String]): Symbol = { val function = functions(operator) require(function != null, s"invalid operator name $operator") require(function.keyVarNumArgs == null || function.keyVarNumArgs.isEmpty, "This function support variable length of Symbol arguments.\n" + "Please pass all the input Symbols via positional arguments instead of keyword arguments.") - val paramKeys = paramKwargs.keys.toArray - val paramVals = paramKwargs.values.toArray + val paramKeys = + if (paramKwargs == null) Array.empty[String] + else (paramKwargs - "name").keys.toArray + val paramVals = + if (paramKwargs == null) Array.empty[String] + else (paramKwargs - "name").values.toArray val symHandle = new SymbolHandleRef checkCall(_LIB.mxSymbolCreateAtomicSymbol( function.handle, paramKeys, paramVals, symHandle)) val s = new Symbol(symHandle.value) - val attrAll = AttrScope.current.get(attr) + val attrAll = AttrScope.current.get(Option(attr)) s.setAttr(attrAll) val hint = operator.toLowerCase - val managedName = NameManager.current.get(name, hint) + val managedName = NameManager.current.get(paramKwargs.get("name"), hint) s.compose(managedName, symbols) s } - private def creator(operator: String, symbols: Map[String, Symbol]): Symbol = { - creator(operator, null, null, Map.empty[String, String], symbols) + def create(operator: String, symbols: Map[String, Symbol]): Symbol = { + create(operator, symbols, null, null) } + def create(operator: String, + symbols: Map[String, Symbol], + paramKwargs: Map[String, String]): Symbol = { + create(operator, symbols, paramKwargs, null) + } + + // a more friendly interface for creating symbols + // all values except symbols in kwargs will be cast to String using its toString() method + def createNoCheck(operator: String, attr: Map[String, String] = null)( + kwargs: Map[String, Any]): Symbol = { + val symbolArgs = kwargs.filter { case (key, value) => + value.isInstanceOf[Symbol] + }.map { case (key, value) => + (key, value.asInstanceOf[Symbol]) + } + val strArgs = kwargs.filter { case (key, value) => + !value.isInstanceOf[Symbol] + }.map { case (key, value) => + (key, value.toString) + } + create(operator, symbolArgs, strArgs, attr) + } } private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala index 6a3b02a23d60..edc6ace47444 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala @@ -3,10 +3,26 @@ package ml.dmlc.mxnet import org.scalatest.{BeforeAndAfterAll, FunSuite} class SymbolSuite extends FunSuite with BeforeAndAfterAll { - test("plus") { - val sym1 = Symbol.Variable("data1") - val sym2 = Symbol.Variable("data2") - val symPlus = sym1 + sym2 - // TODO: check result + test("symbol compose") { + val data = Symbol.Variable("data") + + var net1 = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10)) + net1 = Symbol.FullyConnected(Map("data" -> net1, "name" -> "fc2", "num_hidden" -> 100)) + assert(net1.listArguments() === + Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias")) + + var net2 = Symbol.FullyConnected(Map("name" -> "fc3", "num_hidden" -> 10)) + net2 = Symbol.Activation(Map("data" -> net2, "act_type" -> "relu")) + net2 = Symbol.FullyConnected(Map("data" -> net2, "name" -> "fc4", "num_hidden" -> 20)) + // scalastyle:off println + println(s"net2 debug info:\n${net2.debugStr}") + // scalastyle:on println + + val composed = net2(name = "composed", Map("fc3_data" -> net1)) + // scalastyle:off println + println(s"composed debug info:\n${composed.debugStr}") + // scalastyle:on println + val multiOut = Symbol.Group(composed, net1) + assert(multiOut.listOutputs().length === 2) } } diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 508f26f8378c..f6950d097aa6 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -214,6 +214,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree + (JNIEnv * env, jobject obj, jobject ndArrayHandle) { + return MXNDArrayFree((NDArrayHandle) getLongField(env, ndArrayHandle)); +} + // The related c api MXKVStoreSetUpdater function takes a c function pointer as its parameter, // while we write java functions here in scala-package. // Thus we have to wrap the function in a java object, and run env->CallVoidMethod(obj) once updater is invoked, @@ -484,12 +489,6 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env return rtstr; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env, jobject obj, jobject ndArrayHandle) { - // TODO - puts("Free ndarray called"); - return 0; -} - //IO funcs JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters (JNIEnv * env, jobject obj, jobject creators) { @@ -828,3 +827,64 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetAttr setIntField(env, successRef, success); return ret; } + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListArguments + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject arguments) { + mx_uint outSize; + const char **outStrArray; + int ret = MXSymbolListArguments((SymbolHandle) symbolPtr, &outSize, &outStrArray); + + jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); + jmethodID arrayAppend = env->GetMethodID(arrayClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); + for (int i = 0; i < outSize; i++) { + jstring argument = env->NewStringUTF(outStrArray[i]); + env->CallObjectMethod(arguments, arrayAppend, argument); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListOutputs + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) { + mx_uint outSize; + const char **outStrArray; + int ret = MXSymbolListOutputs((SymbolHandle) symbolPtr, &outSize, &outStrArray); + + jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); + jmethodID arrayAppend = env->GetMethodID(arrayClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); + for (int i = 0; i < outSize; i++) { + jstring output = env->NewStringUTF(outStrArray[i]); + env->CallObjectMethod(outputs, arrayAppend, output); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCopy + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject clonedSymbolRef) { + SymbolHandle clonedSymbol; + int ret = MXSymbolCopy((SymbolHandle) symbolPtr, &clonedSymbol); + setLongField(env, clonedSymbolRef, (long)clonedSymbol); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateGroup + (JNIEnv *env, jobject obj, jlongArray jsymbols, jobject out) { + int numSymbols = env->GetArrayLength(jsymbols); + SymbolHandle handle; + jlong *symbols = env->GetLongArrayElements(jsymbols, NULL); + int ret = MXSymbolCreateGroup(numSymbols, (SymbolHandle *)symbols, &handle); + env->ReleaseLongArrayElements(jsymbols, symbols, 0); + setLongField(env, out, (long)handle); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolPrint + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject out) { + const char *outStr; + int ret = MXSymbolPrint((SymbolHandle) symbolPtr, &outStr); + setStringField(env, out, outStr); + return ret; +} diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 524c9a52eb0f..3e2a9a2494be 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -202,6 +202,11 @@ scala-library ${scala.version} + + org.scala-lang + scala-reflect + ${scala.version} + commons-codec commons-codec