diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala index 03c6bba384c9..2d4b678c01d7 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala @@ -12,12 +12,8 @@ class AttrScope(attr: Map[String, String] = Map.empty) { * @param userDefinedAttr The attribute passed in by user during symbol creation. * @return Updated attributes to add other scope related attributes. */ - def get(userDefinedAttr: Map[String, String]): Map[String, String] = { - if (userDefinedAttr != null) { - attr ++ userDefinedAttr - } else { - attr - } + def get(userDefinedAttr: Option[Map[String, String]]): Map[String, String] = { + _attr ++ userDefinedAttr.getOrElse(Map.empty[String, String]) } def withScope[T](body: => T): T = { diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 62dca8cfb629..0837e6d1ca80 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -145,4 +145,11 @@ class LibInfo { key: String, ret: RefString, success: RefInt): Int + @native def mxSymbolListArguments(handle: SymbolHandle, + arguments: ArrayBuffer[String]): Int + @native def mxSymbolCopy(handle: SymbolHandle, clonedHandle: SymbolHandleRef): Int + @native def mxSymbolListOutputs(handle: SymbolHandle, + outputs: ArrayBuffer[String]): Int + @native def mxSymbolCreateGroup(handles: Array[SymbolHandle], out: SymbolHandleRef): Int + @native def mxSymbolPrint(handle: SymbolHandle, str: RefString): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala index 3cb536ab264f..f81af5ed1724 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala @@ -20,10 +20,8 @@ class NameManager { * @param hint : A hint string, which can be used to generate name. * @return A canonical name for the user. */ - def get(name: String, hint: String): String = { - if (name != null) { - name - } else { + def get(name: Option[String], hint: String): String = { + name.getOrElse { if (!counter.contains(hint)) { counter(hint) = 0 } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 7a8699ee8a25..33e10ad5ea7e 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -3,23 +3,40 @@ package ml.dmlc.mxnet import ml.dmlc.mxnet.Base._ import org.slf4j.LoggerFactory -import scala.collection.mutable.ListBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} /** * Symbolic configuration API of mxnet. * @author Yizhi Liu */ class Symbol(private[mxnet] val handle: SymbolHandle) { - def +(other: Symbol): Symbol = Symbol.creator("_Plus", other) - def +(other: Int): Symbol = ??? - def +(other: Float): Symbol = ??? - def +(other: Double): Symbol = ??? + def +(other: Symbol): Symbol = Symbol.create("_Plus", other) + + override def clone(): Symbol = { + val clonedHandle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCopy(handle, clonedHandle)) + new Symbol(clonedHandle.value) + } /** * List all the arguments in the symbol. * @return Array of all the arguments. */ - def listArguments(): Array[String] = ??? + def listArguments(): Array[String] = { + val arr = ArrayBuffer.empty[String] + checkCall(_LIB.mxSymbolListArguments(handle, arr)) + arr.toArray + } + + /** + * List all outputs in the symbol. + * @return : List of all the outputs. + */ + def listOutputs(): Array[String] = { + val arr = ArrayBuffer.empty[String] + checkCall(_LIB.mxSymbolListOutputs(handle, arr)) + arr.toArray + } /** * List all auxiliary states in the symbol. @@ -49,6 +66,28 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { } } + /** + * Invoke symbol as function on inputs. + * @param name resulting symbol name + * @param symbols provide named symbols + * @return the resulting symbol + */ + def apply(name: String, symbols: Map[String, Symbol]): Symbol = { + val s = clone() + s.compose(name, symbols) + s + } + + /** + * Get a debug string. + * @return Debug string of the symbol. + */ + def debugStr: String = { + val str = new RefString + checkCall(_LIB.mxSymbolPrint(handle, str)) + str.value + } + // Set the attribute of the symbol. private def setAttr(attr: Map[String, String]): Unit = { attr.foreach { case (key, value) => @@ -59,6 +98,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { /** * Compose symbol on inputs. * This call mutates the current symbol. + * @param name resulting symbol name * @param symbols provide positional arguments * @return the resulting symbol */ @@ -70,7 +110,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { private def compose(name: String, symbols: Map[String, Symbol]): Unit = { val keys = symbols.keys.toArray val args = symbols.values.map(_.handle).toArray - checkCall(_LIB.mxSymbolCompose(handle, name, null, args)) + checkCall(_LIB.mxSymbolCompose(handle, name, keys, args)) } } @@ -88,10 +128,38 @@ object Symbol { val handle = new SymbolHandleRef checkCall(_LIB.mxSymbolCreateVariable(name, handle)) val sym = new Symbol(handle.value) - sym.setAttr(AttrScope.current.get(attr)) + sym.setAttr(AttrScope.current.get(Option(attr))) sym } + def FullyConnected: Map[String, Any] => Symbol = { + FullyConnected(null) + } + + def FullyConnected(attr: Map[String, String]): Map[String, Any] => Symbol = { + createNoCheck("FullyConnected", attr) + } + + def Activation: Map[String, Any] => Symbol = { + Activation(null) + } + + def Activation(attr: Map[String, String]): Map[String, Any] => Symbol = { + createNoCheck("Activation", attr) + } + + /** + * Create a symbol that groups symbols together. + * @param symbols List of symbols to be grouped. + * @return The created group symbol. + */ + def Group(symbols: Symbol*): Symbol = { + val ihandles = symbols.map(_.handle).toArray + val handle = new SymbolHandleRef + checkCall(_LIB.mxSymbolCreateGroup(ihandles, handle)) + new Symbol(handle.value) + } + // List and add all the atomic symbol functions to current module. private def initSymbolModule(): Map[String, SymbolFunction] = { val symbolList = ListBuffer.empty[SymbolHandle] @@ -120,30 +188,31 @@ object Symbol { /** * Activation Operator of Neural Net. * The parameters listed below can be passed in as keyword arguments. - * @param name Name of the resulting symbol. - * // TODO + * @param symbols Symbol parameters passed to create the resulting symbol + * @param paramKwargs Key-value parameters passed to create the resulting symbol + * @param attr Attributes set to the resulting symbol * @return the resulting symbol */ - private def creator(operator: String, - name: String, - attr: Map[String, String], - paramKwargs: Map[String, String], - symbols: Symbol*): Symbol = { + def create(operator: String, + symbols: Array[Symbol], + paramKwargs: Map[String, String], + attr: Map[String, String]): Symbol = { val function = functions(operator) require(function != null, s"invalid operator name $operator") + val params = if (paramKwargs == null) Map.empty[String, String] else paramKwargs val addkeyVarNumArgs = (function.keyVarNumArgs != null && !function.keyVarNumArgs.isEmpty - && !paramKwargs.contains(function.keyVarNumArgs)) + && !params.contains(function.keyVarNumArgs)) val paramKeys: Array[String] = ( if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs) else Array.empty[String] - ) ++ paramKwargs.keys + ) ++ (params - "name").keys val paramVals: Array[String] = ( if (addkeyVarNumArgs) Array[String](symbols.length.toString) else Array.empty[String] - ) ++ paramKwargs.values + ) ++ (params - "name").values // create atomic symbol val symHandle = new SymbolHandleRef @@ -151,48 +220,81 @@ object Symbol { function.handle, paramKeys, paramVals, symHandle)) val s = new Symbol(symHandle.value) - val attrAll = AttrScope.current.get(attr) + val attrAll = AttrScope.current.get(Option(attr)) s.setAttr(attrAll) val hint = operator.toLowerCase - val managedName = NameManager.current.get(name, hint) - s.compose(managedName, symbols.toArray) + val managedName = NameManager.current.get(params.get("name"), hint) + s.compose(managedName, symbols) s } - private def creator(operator: String, symbols: Symbol*): Symbol = { - creator(operator, null, null, Map.empty[String, String], symbols:_*) + def create(operator: String, symbols: Symbol*): Symbol = { + create(operator, symbols.toArray, null, null) } - private def creator(operator: String, - name: String, - attr: Map[String, String], - paramKwargs: Map[String, String], - symbols: Map[String, Symbol]): Symbol = { + /** + * Activation Operator of Neural Net. + * The parameters listed below can be passed in as keyword arguments. + * @param symbols Named symbol parameters passed to create the resulting symbol + * @param paramKwargs Key-value parameters passed to create the resulting symbol + * @param attr Attributes set to the resulting symbol + * @return the resulting symbol + */ + private def create(operator: String, + symbols: Map[String, Symbol], + paramKwargs: Map[String, String], + attr: Map[String, String]): Symbol = { val function = functions(operator) require(function != null, s"invalid operator name $operator") require(function.keyVarNumArgs == null || function.keyVarNumArgs.isEmpty, "This function support variable length of Symbol arguments.\n" + "Please pass all the input Symbols via positional arguments instead of keyword arguments.") - val paramKeys = paramKwargs.keys.toArray - val paramVals = paramKwargs.values.toArray + val paramKeys = + if (paramKwargs == null) Array.empty[String] + else (paramKwargs - "name").keys.toArray + val paramVals = + if (paramKwargs == null) Array.empty[String] + else (paramKwargs - "name").values.toArray val symHandle = new SymbolHandleRef checkCall(_LIB.mxSymbolCreateAtomicSymbol( function.handle, paramKeys, paramVals, symHandle)) val s = new Symbol(symHandle.value) - val attrAll = AttrScope.current.get(attr) + val attrAll = AttrScope.current.get(Option(attr)) s.setAttr(attrAll) val hint = operator.toLowerCase - val managedName = NameManager.current.get(name, hint) + val managedName = NameManager.current.get(paramKwargs.get("name"), hint) s.compose(managedName, symbols) s } - private def creator(operator: String, symbols: Map[String, Symbol]): Symbol = { - creator(operator, null, null, Map.empty[String, String], symbols) + def create(operator: String, symbols: Map[String, Symbol]): Symbol = { + create(operator, symbols, null, null) } + def create(operator: String, + symbols: Map[String, Symbol], + paramKwargs: Map[String, String]): Symbol = { + create(operator, symbols, paramKwargs, null) + } + + // a more friendly interface for creating symbols + // all values except symbols in kwargs will be cast to String using its toString() method + def createNoCheck(operator: String, attr: Map[String, String] = null)( + kwargs: Map[String, Any]): Symbol = { + val symbolArgs = kwargs.filter { case (key, value) => + value.isInstanceOf[Symbol] + }.map { case (key, value) => + (key, value.asInstanceOf[Symbol]) + } + val strArgs = kwargs.filter { case (key, value) => + !value.isInstanceOf[Symbol] + }.map { case (key, value) => + (key, value.toString) + } + create(operator, symbolArgs, strArgs, attr) + } } private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala index 6a3b02a23d60..edc6ace47444 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala @@ -3,10 +3,26 @@ package ml.dmlc.mxnet import org.scalatest.{BeforeAndAfterAll, FunSuite} class SymbolSuite extends FunSuite with BeforeAndAfterAll { - test("plus") { - val sym1 = Symbol.Variable("data1") - val sym2 = Symbol.Variable("data2") - val symPlus = sym1 + sym2 - // TODO: check result + test("symbol compose") { + val data = Symbol.Variable("data") + + var net1 = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10)) + net1 = Symbol.FullyConnected(Map("data" -> net1, "name" -> "fc2", "num_hidden" -> 100)) + assert(net1.listArguments() === + Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias")) + + var net2 = Symbol.FullyConnected(Map("name" -> "fc3", "num_hidden" -> 10)) + net2 = Symbol.Activation(Map("data" -> net2, "act_type" -> "relu")) + net2 = Symbol.FullyConnected(Map("data" -> net2, "name" -> "fc4", "num_hidden" -> 20)) + // scalastyle:off println + println(s"net2 debug info:\n${net2.debugStr}") + // scalastyle:on println + + val composed = net2(name = "composed", Map("fc3_data" -> net1)) + // scalastyle:off println + println(s"composed debug info:\n${composed.debugStr}") + // scalastyle:on println + val multiOut = Symbol.Group(composed, net1) + assert(multiOut.listOutputs().length === 2) } } diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 508f26f8378c..f6950d097aa6 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -214,6 +214,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree + (JNIEnv * env, jobject obj, jobject ndArrayHandle) { + return MXNDArrayFree((NDArrayHandle) getLongField(env, ndArrayHandle)); +} + // The related c api MXKVStoreSetUpdater function takes a c function pointer as its parameter, // while we write java functions here in scala-package. // Thus we have to wrap the function in a java object, and run env->CallVoidMethod(obj) once updater is invoked, @@ -484,12 +489,6 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env return rtstr; } -JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env, jobject obj, jobject ndArrayHandle) { - // TODO - puts("Free ndarray called"); - return 0; -} - //IO funcs JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters (JNIEnv * env, jobject obj, jobject creators) { @@ -828,3 +827,64 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetAttr setIntField(env, successRef, success); return ret; } + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListArguments + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject arguments) { + mx_uint outSize; + const char **outStrArray; + int ret = MXSymbolListArguments((SymbolHandle) symbolPtr, &outSize, &outStrArray); + + jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); + jmethodID arrayAppend = env->GetMethodID(arrayClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); + for (int i = 0; i < outSize; i++) { + jstring argument = env->NewStringUTF(outStrArray[i]); + env->CallObjectMethod(arguments, arrayAppend, argument); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListOutputs + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) { + mx_uint outSize; + const char **outStrArray; + int ret = MXSymbolListOutputs((SymbolHandle) symbolPtr, &outSize, &outStrArray); + + jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); + jmethodID arrayAppend = env->GetMethodID(arrayClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); + for (int i = 0; i < outSize; i++) { + jstring output = env->NewStringUTF(outStrArray[i]); + env->CallObjectMethod(outputs, arrayAppend, output); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCopy + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject clonedSymbolRef) { + SymbolHandle clonedSymbol; + int ret = MXSymbolCopy((SymbolHandle) symbolPtr, &clonedSymbol); + setLongField(env, clonedSymbolRef, (long)clonedSymbol); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateGroup + (JNIEnv *env, jobject obj, jlongArray jsymbols, jobject out) { + int numSymbols = env->GetArrayLength(jsymbols); + SymbolHandle handle; + jlong *symbols = env->GetLongArrayElements(jsymbols, NULL); + int ret = MXSymbolCreateGroup(numSymbols, (SymbolHandle *)symbols, &handle); + env->ReleaseLongArrayElements(jsymbols, symbols, 0); + setLongField(env, out, (long)handle); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolPrint + (JNIEnv *env, jobject obj, jlong symbolPtr, jobject out) { + const char *outStr; + int ret = MXSymbolPrint((SymbolHandle) symbolPtr, &outStr); + setStringField(env, out, outStr); + return ret; +} diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 524c9a52eb0f..3e2a9a2494be 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -202,6 +202,11 @@ scala-library ${scala.version} + + org.scala-lang + scala-reflect + ${scala.version} + commons-codec commons-codec