diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala
index 03c6bba384c9..2d4b678c01d7 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/AttrScope.scala
@@ -12,12 +12,8 @@ class AttrScope(attr: Map[String, String] = Map.empty) {
* @param userDefinedAttr The attribute passed in by user during symbol creation.
* @return Updated attributes to add other scope related attributes.
*/
- def get(userDefinedAttr: Map[String, String]): Map[String, String] = {
- if (userDefinedAttr != null) {
- attr ++ userDefinedAttr
- } else {
- attr
- }
+ def get(userDefinedAttr: Option[Map[String, String]]): Map[String, String] = {
+ _attr ++ userDefinedAttr.getOrElse(Map.empty[String, String])
}
def withScope[T](body: => T): T = {
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
index 62dca8cfb629..0837e6d1ca80 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
@@ -145,4 +145,11 @@ class LibInfo {
key: String,
ret: RefString,
success: RefInt): Int
+ @native def mxSymbolListArguments(handle: SymbolHandle,
+ arguments: ArrayBuffer[String]): Int
+ @native def mxSymbolCopy(handle: SymbolHandle, clonedHandle: SymbolHandleRef): Int
+ @native def mxSymbolListOutputs(handle: SymbolHandle,
+ outputs: ArrayBuffer[String]): Int
+ @native def mxSymbolCreateGroup(handles: Array[SymbolHandle], out: SymbolHandleRef): Int
+ @native def mxSymbolPrint(handle: SymbolHandle, str: RefString): Int
}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala
index 3cb536ab264f..f81af5ed1724 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NameManager.scala
@@ -20,10 +20,8 @@ class NameManager {
* @param hint : A hint string, which can be used to generate name.
* @return A canonical name for the user.
*/
- def get(name: String, hint: String): String = {
- if (name != null) {
- name
- } else {
+ def get(name: Option[String], hint: String): String = {
+ name.getOrElse {
if (!counter.contains(hint)) {
counter(hint) = 0
}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
index 7a8699ee8a25..33e10ad5ea7e 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
@@ -3,23 +3,40 @@ package ml.dmlc.mxnet
import ml.dmlc.mxnet.Base._
import org.slf4j.LoggerFactory
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
/**
* Symbolic configuration API of mxnet.
* @author Yizhi Liu
*/
class Symbol(private[mxnet] val handle: SymbolHandle) {
- def +(other: Symbol): Symbol = Symbol.creator("_Plus", other)
- def +(other: Int): Symbol = ???
- def +(other: Float): Symbol = ???
- def +(other: Double): Symbol = ???
+ def +(other: Symbol): Symbol = Symbol.create("_Plus", other)
+
+ override def clone(): Symbol = {
+ val clonedHandle = new SymbolHandleRef
+ checkCall(_LIB.mxSymbolCopy(handle, clonedHandle))
+ new Symbol(clonedHandle.value)
+ }
/**
* List all the arguments in the symbol.
* @return Array of all the arguments.
*/
- def listArguments(): Array[String] = ???
+ def listArguments(): Array[String] = {
+ val arr = ArrayBuffer.empty[String]
+ checkCall(_LIB.mxSymbolListArguments(handle, arr))
+ arr.toArray
+ }
+
+ /**
+ * List all outputs in the symbol.
+ * @return : List of all the outputs.
+ */
+ def listOutputs(): Array[String] = {
+ val arr = ArrayBuffer.empty[String]
+ checkCall(_LIB.mxSymbolListOutputs(handle, arr))
+ arr.toArray
+ }
/**
* List all auxiliary states in the symbol.
@@ -49,6 +66,28 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
}
}
+ /**
+ * Invoke symbol as function on inputs.
+ * @param name resulting symbol name
+ * @param symbols provide named symbols
+ * @return the resulting symbol
+ */
+ def apply(name: String, symbols: Map[String, Symbol]): Symbol = {
+ val s = clone()
+ s.compose(name, symbols)
+ s
+ }
+
+ /**
+ * Get a debug string.
+ * @return Debug string of the symbol.
+ */
+ def debugStr: String = {
+ val str = new RefString
+ checkCall(_LIB.mxSymbolPrint(handle, str))
+ str.value
+ }
+
// Set the attribute of the symbol.
private def setAttr(attr: Map[String, String]): Unit = {
attr.foreach { case (key, value) =>
@@ -59,6 +98,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
/**
* Compose symbol on inputs.
* This call mutates the current symbol.
+ * @param name resulting symbol name
* @param symbols provide positional arguments
* @return the resulting symbol
*/
@@ -70,7 +110,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
private def compose(name: String, symbols: Map[String, Symbol]): Unit = {
val keys = symbols.keys.toArray
val args = symbols.values.map(_.handle).toArray
- checkCall(_LIB.mxSymbolCompose(handle, name, null, args))
+ checkCall(_LIB.mxSymbolCompose(handle, name, keys, args))
}
}
@@ -88,10 +128,38 @@ object Symbol {
val handle = new SymbolHandleRef
checkCall(_LIB.mxSymbolCreateVariable(name, handle))
val sym = new Symbol(handle.value)
- sym.setAttr(AttrScope.current.get(attr))
+ sym.setAttr(AttrScope.current.get(Option(attr)))
sym
}
+ def FullyConnected: Map[String, Any] => Symbol = {
+ FullyConnected(null)
+ }
+
+ def FullyConnected(attr: Map[String, String]): Map[String, Any] => Symbol = {
+ createNoCheck("FullyConnected", attr)
+ }
+
+ def Activation: Map[String, Any] => Symbol = {
+ Activation(null)
+ }
+
+ def Activation(attr: Map[String, String]): Map[String, Any] => Symbol = {
+ createNoCheck("Activation", attr)
+ }
+
+ /**
+ * Create a symbol that groups symbols together.
+ * @param symbols List of symbols to be grouped.
+ * @return The created group symbol.
+ */
+ def Group(symbols: Symbol*): Symbol = {
+ val ihandles = symbols.map(_.handle).toArray
+ val handle = new SymbolHandleRef
+ checkCall(_LIB.mxSymbolCreateGroup(ihandles, handle))
+ new Symbol(handle.value)
+ }
+
// List and add all the atomic symbol functions to current module.
private def initSymbolModule(): Map[String, SymbolFunction] = {
val symbolList = ListBuffer.empty[SymbolHandle]
@@ -120,30 +188,31 @@ object Symbol {
/**
* Activation Operator of Neural Net.
* The parameters listed below can be passed in as keyword arguments.
- * @param name Name of the resulting symbol.
- * // TODO
+ * @param symbols Symbol parameters passed to create the resulting symbol
+ * @param paramKwargs Key-value parameters passed to create the resulting symbol
+ * @param attr Attributes set to the resulting symbol
* @return the resulting symbol
*/
- private def creator(operator: String,
- name: String,
- attr: Map[String, String],
- paramKwargs: Map[String, String],
- symbols: Symbol*): Symbol = {
+ def create(operator: String,
+ symbols: Array[Symbol],
+ paramKwargs: Map[String, String],
+ attr: Map[String, String]): Symbol = {
val function = functions(operator)
require(function != null, s"invalid operator name $operator")
+ val params = if (paramKwargs == null) Map.empty[String, String] else paramKwargs
val addkeyVarNumArgs = (function.keyVarNumArgs != null
&& !function.keyVarNumArgs.isEmpty
- && !paramKwargs.contains(function.keyVarNumArgs))
+ && !params.contains(function.keyVarNumArgs))
val paramKeys: Array[String] = (
if (addkeyVarNumArgs) Array[String](function.keyVarNumArgs)
else Array.empty[String]
- ) ++ paramKwargs.keys
+ ) ++ (params - "name").keys
val paramVals: Array[String] = (
if (addkeyVarNumArgs) Array[String](symbols.length.toString)
else Array.empty[String]
- ) ++ paramKwargs.values
+ ) ++ (params - "name").values
// create atomic symbol
val symHandle = new SymbolHandleRef
@@ -151,48 +220,81 @@ object Symbol {
function.handle, paramKeys, paramVals, symHandle))
val s = new Symbol(symHandle.value)
- val attrAll = AttrScope.current.get(attr)
+ val attrAll = AttrScope.current.get(Option(attr))
s.setAttr(attrAll)
val hint = operator.toLowerCase
- val managedName = NameManager.current.get(name, hint)
- s.compose(managedName, symbols.toArray)
+ val managedName = NameManager.current.get(params.get("name"), hint)
+ s.compose(managedName, symbols)
s
}
- private def creator(operator: String, symbols: Symbol*): Symbol = {
- creator(operator, null, null, Map.empty[String, String], symbols:_*)
+ def create(operator: String, symbols: Symbol*): Symbol = {
+ create(operator, symbols.toArray, null, null)
}
- private def creator(operator: String,
- name: String,
- attr: Map[String, String],
- paramKwargs: Map[String, String],
- symbols: Map[String, Symbol]): Symbol = {
+ /**
+ * Activation Operator of Neural Net.
+ * The parameters listed below can be passed in as keyword arguments.
+ * @param symbols Named symbol parameters passed to create the resulting symbol
+ * @param paramKwargs Key-value parameters passed to create the resulting symbol
+ * @param attr Attributes set to the resulting symbol
+ * @return the resulting symbol
+ */
+ private def create(operator: String,
+ symbols: Map[String, Symbol],
+ paramKwargs: Map[String, String],
+ attr: Map[String, String]): Symbol = {
val function = functions(operator)
require(function != null, s"invalid operator name $operator")
require(function.keyVarNumArgs == null || function.keyVarNumArgs.isEmpty,
"This function support variable length of Symbol arguments.\n" +
"Please pass all the input Symbols via positional arguments instead of keyword arguments.")
- val paramKeys = paramKwargs.keys.toArray
- val paramVals = paramKwargs.values.toArray
+ val paramKeys =
+ if (paramKwargs == null) Array.empty[String]
+ else (paramKwargs - "name").keys.toArray
+ val paramVals =
+ if (paramKwargs == null) Array.empty[String]
+ else (paramKwargs - "name").values.toArray
val symHandle = new SymbolHandleRef
checkCall(_LIB.mxSymbolCreateAtomicSymbol(
function.handle, paramKeys, paramVals, symHandle))
val s = new Symbol(symHandle.value)
- val attrAll = AttrScope.current.get(attr)
+ val attrAll = AttrScope.current.get(Option(attr))
s.setAttr(attrAll)
val hint = operator.toLowerCase
- val managedName = NameManager.current.get(name, hint)
+ val managedName = NameManager.current.get(paramKwargs.get("name"), hint)
s.compose(managedName, symbols)
s
}
- private def creator(operator: String, symbols: Map[String, Symbol]): Symbol = {
- creator(operator, null, null, Map.empty[String, String], symbols)
+ def create(operator: String, symbols: Map[String, Symbol]): Symbol = {
+ create(operator, symbols, null, null)
}
+ def create(operator: String,
+ symbols: Map[String, Symbol],
+ paramKwargs: Map[String, String]): Symbol = {
+ create(operator, symbols, paramKwargs, null)
+ }
+
+ // a more friendly interface for creating symbols
+ // all values except symbols in kwargs will be cast to String using its toString() method
+ def createNoCheck(operator: String, attr: Map[String, String] = null)(
+ kwargs: Map[String, Any]): Symbol = {
+ val symbolArgs = kwargs.filter { case (key, value) =>
+ value.isInstanceOf[Symbol]
+ }.map { case (key, value) =>
+ (key, value.asInstanceOf[Symbol])
+ }
+ val strArgs = kwargs.filter { case (key, value) =>
+ !value.isInstanceOf[Symbol]
+ }.map { case (key, value) =>
+ (key, value.toString)
+ }
+ create(operator, symbolArgs, strArgs, attr)
+ }
}
private case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String)
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
index 6a3b02a23d60..edc6ace47444 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/SymbolSuite.scala
@@ -3,10 +3,26 @@ package ml.dmlc.mxnet
import org.scalatest.{BeforeAndAfterAll, FunSuite}
class SymbolSuite extends FunSuite with BeforeAndAfterAll {
- test("plus") {
- val sym1 = Symbol.Variable("data1")
- val sym2 = Symbol.Variable("data2")
- val symPlus = sym1 + sym2
- // TODO: check result
+ test("symbol compose") {
+ val data = Symbol.Variable("data")
+
+ var net1 = Symbol.FullyConnected(Map("data" -> data, "name" -> "fc1", "num_hidden" -> 10))
+ net1 = Symbol.FullyConnected(Map("data" -> net1, "name" -> "fc2", "num_hidden" -> 100))
+ assert(net1.listArguments() ===
+ Array("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"))
+
+ var net2 = Symbol.FullyConnected(Map("name" -> "fc3", "num_hidden" -> 10))
+ net2 = Symbol.Activation(Map("data" -> net2, "act_type" -> "relu"))
+ net2 = Symbol.FullyConnected(Map("data" -> net2, "name" -> "fc4", "num_hidden" -> 20))
+ // scalastyle:off println
+ println(s"net2 debug info:\n${net2.debugStr}")
+ // scalastyle:on println
+
+ val composed = net2(name = "composed", Map("fc3_data" -> net1))
+ // scalastyle:off println
+ println(s"composed debug info:\n${composed.debugStr}")
+ // scalastyle:on println
+ val multiOut = Symbol.Group(composed, net1)
+ assert(multiOut.listOutputs().length === 2)
}
}
diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
index 508f26f8378c..f6950d097aa6 100644
--- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
@@ -214,6 +214,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
return ret;
}
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree
+ (JNIEnv * env, jobject obj, jobject ndArrayHandle) {
+ return MXNDArrayFree((NDArrayHandle) getLongField(env, ndArrayHandle));
+}
+
// The related c api MXKVStoreSetUpdater function takes a c function pointer as its parameter,
// while we write java functions here in scala-package.
// Thus we have to wrap the function in a java object, and run env->CallVoidMethod(obj) once updater is invoked,
@@ -484,12 +489,6 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env
return rtstr;
}
-JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env, jobject obj, jobject ndArrayHandle) {
- // TODO
- puts("Free ndarray called");
- return 0;
-}
-
//IO funcs
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters
(JNIEnv * env, jobject obj, jobject creators) {
@@ -828,3 +827,64 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolGetAttr
setIntField(env, successRef, success);
return ret;
}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListArguments
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jobject arguments) {
+ mx_uint outSize;
+ const char **outStrArray;
+ int ret = MXSymbolListArguments((SymbolHandle) symbolPtr, &outSize, &outStrArray);
+
+ jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
+ jmethodID arrayAppend = env->GetMethodID(arrayClass,
+ "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
+ for (int i = 0; i < outSize; i++) {
+ jstring argument = env->NewStringUTF(outStrArray[i]);
+ env->CallObjectMethod(arguments, arrayAppend, argument);
+ }
+
+ return ret;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolListOutputs
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jobject outputs) {
+ mx_uint outSize;
+ const char **outStrArray;
+ int ret = MXSymbolListOutputs((SymbolHandle) symbolPtr, &outSize, &outStrArray);
+
+ jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
+ jmethodID arrayAppend = env->GetMethodID(arrayClass,
+ "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
+ for (int i = 0; i < outSize; i++) {
+ jstring output = env->NewStringUTF(outStrArray[i]);
+ env->CallObjectMethod(outputs, arrayAppend, output);
+ }
+
+ return ret;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCopy
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jobject clonedSymbolRef) {
+ SymbolHandle clonedSymbol;
+ int ret = MXSymbolCopy((SymbolHandle) symbolPtr, &clonedSymbol);
+ setLongField(env, clonedSymbolRef, (long)clonedSymbol);
+ return ret;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolCreateGroup
+ (JNIEnv *env, jobject obj, jlongArray jsymbols, jobject out) {
+ int numSymbols = env->GetArrayLength(jsymbols);
+ SymbolHandle handle;
+ jlong *symbols = env->GetLongArrayElements(jsymbols, NULL);
+ int ret = MXSymbolCreateGroup(numSymbols, (SymbolHandle *)symbols, &handle);
+ env->ReleaseLongArrayElements(jsymbols, symbols, 0);
+ setLongField(env, out, (long)handle);
+ return ret;
+}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxSymbolPrint
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jobject out) {
+ const char *outStr;
+ int ret = MXSymbolPrint((SymbolHandle) symbolPtr, &outStr);
+ setStringField(env, out, outStr);
+ return ret;
+}
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 524c9a52eb0f..3e2a9a2494be 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -202,6 +202,11 @@
scala-library
${scala.version}
+
+ org.scala-lang
+ scala-reflect
+ ${scala.version}
+
commons-codec
commons-codec