From 884f14be8d4d5f818dd76b7c8564d7f6185336c5 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 24 Dec 2015 23:17:51 +0800 Subject: [PATCH 01/13] add Model.scala with createKVStore & initializeKVStore --- .../src/main/scala/ml/dmlc/mxnet/Model.scala | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala new file mode 100644 index 000000000000..819cc70f417d --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -0,0 +1,62 @@ +package ml.dmlc.mxnet + +import org.slf4j.LoggerFactory + +object Model { + private val logger = LoggerFactory.getLogger(classOf[Model]) + /** + * Create kvstore + * This function select and create a proper kvstore given the kvstore type + * @param kvStore KVStore type + * @param numDevice The number of devices + * @param maxSize max size of the kvstore + * @return Option of created [[KVStore]] and whether or not update weight on it + */ + private def createKVStore(kvStore: String, numDevice: Int, maxSize: Int): (Option[KVStore], Boolean) = { + if (numDevice == 1 && !kvStore.contains("dist")) { + // no need to use kv for single device and single machine + (None, false) + } else { + var kvType = kvStore + if (kvType == "local") { + //automatically select a proper local + kvType = + if (maxSize < 1024 * 1024 * 16) { + "local_update_cpu" + } else { + "local_allreduce_cpu" + } + logger.info(s"Auto - select kvstore type = $kvType") + } + (Option(KVStore.create(kvType)), !kvType.contains("local_allreduce")) + } + } + + /** + * Create a kvstore (wrap it with Option, None if given kvStore == null) + * @param kvStore + * @return Option of created [[KVStore]] and whether or not update weight on it + */ + private def createKVStore(kvStore: KVStore): (Option[KVStore], Boolean) = { + (Option(kvStore), kvStore != null && !kvStore.`type`.contains("local_allreduce")) + } + + // Initialize kvstore + private def initializeKVStore(kvStore: KVStore, + paramArrays: Array[NDArray], + argParams: Map[String, NDArray], + paramNames: Array[String], + updateOnKVStore: Boolean): Unit = { + require(paramArrays.length == paramNames.length) + for (idx <- 0 until paramArrays.length) { + val paramOnDevs = paramArrays(idx) + kvStore.init(idx, argParams(paramNames(idx))) + if (updateOnKVStore) { + kvStore.pull(idx, paramOnDevs, -idx) + } + } + } +} + +class Model { +} From 2ce0e511cfa68a24550993709af1e546ca7d744d Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Fri, 25 Dec 2015 20:45:20 +0800 Subject: [PATCH 02/13] some Model functions, add (key, array-of-values) pull & push to KVStore --- .../main/scala/ml/dmlc/mxnet/KVStore.scala | 18 ++++++++ .../src/main/scala/ml/dmlc/mxnet/Model.scala | 44 ++++++++++++++++++- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index 536123c57c52..3882383f70da 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -71,6 +71,15 @@ class KVStore(private val handle: KVStoreHandle) { push(Array(key), Array(value), priority) } + def push(key: Int, values: Array[NDArray], priority: Int): Unit = { + val keys = Array.fill(values.length)(key) + push(keys, values, priority) + } + + def push(key: Int, values: Array[NDArray]): Unit = { + push(key, values, 0) + } + /** * Pull a single value or a sequence of values from the store. * @@ -98,6 +107,15 @@ class KVStore(private val handle: KVStoreHandle) { pull(Array(key), Array(out), priority) } + def pull(key: Int, outs: Array[NDArray], priority: Int): Unit = { + val keys = Array.fill(outs.length)(key) + pull(keys, outs, priority) + } + + def pull(key: Int, outs: Array[NDArray]): Unit = { + pull(key, outs, 0) + } + // Get the type of this kvstore def `type`: String = { val kvType = new RefString diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index 819cc70f417d..4277bc08c654 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -2,6 +2,10 @@ package ml.dmlc.mxnet import org.slf4j.LoggerFactory +/** + * Describe the model flow + * @author Yizhi Liu + */ object Model { private val logger = LoggerFactory.getLogger(classOf[Model]) /** @@ -43,7 +47,7 @@ object Model { // Initialize kvstore private def initializeKVStore(kvStore: KVStore, - paramArrays: Array[NDArray], + paramArrays: Array[Array[NDArray]], argParams: Map[String, NDArray], paramNames: Array[String], updateOnKVStore: Boolean): Unit = { @@ -56,6 +60,44 @@ object Model { } } } + + // Perform update of param_arrays from grad_arrays on kvstore + private def updateParamsOnKVStore(paramArrays: Array[Array[NDArray]], + gradArrays: Array[Array[NDArray]], + kvStore: KVStore): Unit = { + (paramArrays zip gradArrays).zipWithIndex.foreach { case ((argList, gradList), index) => + if (gradList != null) { + // push gradient, priority is negative index + kvStore.push(index, gradList, -index) + // pull back the weights + kvStore.pull(index, argList, -index) + } + } + } + + // Perform update of param_arrays from grad_arrays not on kvstore + private def updateParams(paramArrays: Array[Array[NDArray]], + gradArrays: Array[Array[NDArray]], + updater: MXKVStoreUpdater, + numDevice: Int, + kvStore: Option[KVStore] = None) { + (paramArrays zip gradArrays).zipWithIndex.foreach { case ((argList, gradList), index) => + if (gradList != null) { + kvStore.foreach(kv => { + // push gradient, priority is negative index + kv.push(index, gradList, -index) + // pull back the sum gradients, to the same locations. + kv.pull(index, gradList, -index) + }) + (argList zip gradList).zipWithIndex.foreach { case ((w: NDArray, g: NDArray), k: Int) => + // faked an index here, to make optimizer create diff + // state for the same index but on diff devs, + // (copy from python package) TODO(mli) use a better solution latter + updater.update(index * numDevice + k, g, w) + } + } + } + } } class Model { From 9e3fe872e09ff25f9153323ea421358d68377818 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 26 Dec 2015 00:15:52 +0800 Subject: [PATCH 03/13] Executor::forward and other helper functions --- .../src/main/scala/ml/dmlc/mxnet/Base.scala | 1 + .../main/scala/ml/dmlc/mxnet/Executor.scala | 75 +++++++++++++++++-- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 8 +- .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 5 ++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 36 +++++++++ 5 files changed, 119 insertions(+), 6 deletions(-) create mode 100644 scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 85047c09acaf..faa1cec9bc27 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -16,6 +16,7 @@ object Base { type NDArrayHandle = RefLong type FunctionHandle = RefLong type KVStoreHandle = RefLong + type ExecutorHandle = RefLong System.loadLibrary("mxnet-scala") val _LIB = new LibInfo diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index 4eb51f74cc5f..83265e9b3368 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -1,11 +1,76 @@ package ml.dmlc.mxnet +import ml.dmlc.mxnet.Base._ + +import scala.collection.mutable.ArrayBuffer + +object Executor { + // Get the dictionary given name and ndarray pairs. + private def getDict(names: Array[String], ndarrays: Array[NDArray]): Map[String, NDArray] = { + require(names.toSet.size == names.length, "Duplicate names detected") + (names zip ndarrays).toMap + } +} + /** - * Created by yuantang on 12/23/15. + * Symbolic Executor component of MXNet + * @author Yizhi Liu + * + * Constructor: used Symbol.bind and Symbol.simple_bind instead. + * @param handle ExecutorHandle generated by calling Bind + * @param symbol + * @see Symbol.bind : to create executor */ -abstract class Executor(var argArrays: Array[NDArray]) { - def forward - def backward - def setMonitorCallback(callback: (String, NDArray) => Any) +class Executor(val handle: ExecutorHandle, val symbol: Symbol) { + var argArrays: Array = _ + var gradArrays: Array = _ + var auxArrays: Array = _ + var outputs: Array[NDArray] = getOutputs + var argDict: Map[String, NDArray] = _ + /* + self._grad_dict = None + self._aux_dict = None + self._monitor_callback = None + */ + + override def finalize(): Unit = { + checkCall(_LIB.mxExecutorFree(handle)) + } + + /** + * list all the output ndarray + * @return A list of ndarray binded to the heads of executor. + */ + private def getOutputs: Array[NDArray] = { + val ndHandles = ArrayBuffer[NDArrayHandle]() + checkCall(_LIB.mxExecutorOutputs(handle, ndHandles)) + ndHandles.toArray.map(new NDArray(_)) + } + + /** + * Calculate the outputs specified by the binded symbol. + * @param isTrain whether this forward is for evaluation purpose. + * @param kwargs Additional specification of input arguments. + */ + def forward(isTrain: Boolean, kwargs: (String, NDArray)*): Unit = { + kwargs.foreach { case (name, array) => + require(argDict.contains(name), s"Unknown argument $name") + array.copyTo(argDict(name)) + } + checkCall(_LIB.mxExecutorForward(handle, if (isTrain) 1 else 0)) + } + + def forward(): Unit = { + forward(isTrain = false) + } + /** + * Do backward pass to get the gradient of arguments. + * @param outGrads + * Gradient on the outputs to be propagated back. + * This parameter is only needed when bind is called + * on outputs that are not a loss function. + */ + def backward(outGrads: Array[NDArray]):Unit = { + } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 2592394b8550..2267b53b141f 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -4,7 +4,10 @@ import ml.dmlc.mxnet.Base._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} -// JNI functions +/** + * JNI functions + * @author Yizhi Liu + */ class LibInfo { @native def mxNDArrayFree(handle: NDArrayHandle): Int @native def mxGetLastError(): String @@ -81,4 +84,7 @@ class LibInfo { @native def mxKVStoreBarrier(handle: KVStoreHandle): Int @native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int @native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int + @native def mxExecutorOutputs(handle: ExecutorHandle, outputs: ArrayBuffer[NDArrayHandle]): Int + @native def mxExecutorFree(handle: ExecutorHandle): Int + @native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala new file mode 100644 index 000000000000..7835fb557063 --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -0,0 +1,5 @@ +package ml.dmlc.mxnet + +class Symbol { + +} diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index e1bb66d45858..c15cb3b5e009 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -386,6 +386,42 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetRank return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorOutputs + (JNIEnv *env, jobject obj, jobject executorHandle, jobject outputs) { + + jlong executorPtr = getLongField(env, executorHandle); + mx_uint outSize; + NDArrayHandle *out; + int ret = MXExecutorOutputs((ExecutorHandle)executorPtr, &outSize, &out); + + // Base.ExecutorHandle.constructor + jclass ndArrayClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jmethodID ndArrayConstructor = env->GetMethodID(fhClass,"","(J)V"); + + // fill java outputs + jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); + jmethodID arrayAppend = env->GetMethodID(arrayClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;"); + for (int i = 0; i < outSize; ++i) { + jobject ndArray = env->NewObject(ndArrayClass, ndArrayConstructor, (long)out[i]); + env->CallObjectMethod(outputs, arrayAppend, ndArray); + } + + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorFree + (JNIEnv * env, jobject obj, jobject handle) { + jlong ptr = getLongField(env, handle); + return MXExecutorFree((ExecutorHandle) ptr); +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorForward + (JNIEnv * env, jobject obj, jobject handle, jint isTrain) { + jlong ptr = getLongField(env, handle); + return MXExecutorForward((ExecutorHandle)ptr, (int)isTrain); +} + JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) { char *tmpstr = "MXNetError"; jstring rtstr = env->NewStringUTF(tmpstr); From ffd819c816443bbc6e15857b997fa6d96fd49289 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 26 Dec 2015 15:09:15 +0800 Subject: [PATCH 04/13] Executor class functions --- .../main/scala/ml/dmlc/mxnet/Executor.scala | 102 +++++++++++++++++- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 6 ++ .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 16 +++ .../native/src/main/native/jni_helper_func.h | 6 ++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 21 +++- 5 files changed, 145 insertions(+), 6 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index 83265e9b3368..ae40e68c5d30 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -22,14 +22,14 @@ object Executor { * @see Symbol.bind : to create executor */ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { - var argArrays: Array = _ - var gradArrays: Array = _ - var auxArrays: Array = _ + var argArrays: Array[NDArray] = null + var gradArrays: Array[NDArray] = null + var auxArrays: Array[NDArray] = null var outputs: Array[NDArray] = getOutputs - var argDict: Map[String, NDArray] = _ + var _argDict: Map[String, NDArray] = null + var _auxDict: Map[String, NDArray] = null /* self._grad_dict = None - self._aux_dict = None self._monitor_callback = None */ @@ -72,5 +72,97 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { * on outputs that are not a loss function. */ def backward(outGrads: Array[NDArray]):Unit = { + require(outGrads != null) + val ndArrayPtrs = outGrads.map(_.handle.value) + checkCall(_LIB.mxExecutorBackward(handle, outGrads.length, ndArrayPtrs)) + } + + def backward(outGrad: NDArray): Unit = { + require(outGrad != null) + backward(Array(outGrad)) + } + + def backward(): Unit = { + backward(Array.empty[NDArray]) + } + + /** + * Install callback. + * callback Takes a string and an NDArrayHandle. + */ + def setMonitorCallback(callback: (String, NDArrayHandle) => Unit): Unit = ??? + + /** + * Get dictionary representation of argument arrrays. + * @return The dictionary that maps name of arguments to NDArrays. + * @throws IllegalArgumentException if there are duplicated names in the arguments. + */ + def argDict: Map[String, NDArray] = { + if (_argDict == null) { + _argDict = Executor.getDict(symbol.listArguments(), argArrays) + } + _argDict + } + + /** + * Get dictionary representation of auxiliary states arrays. + * @return The dictionary that maps name of auxiliary states to NDArrays. + * @throws IllegalArgumentException if there are duplicated names in the auxiliary states. + */ + def auxDict: Map[String, NDArray] = { + if (_auxDict == null) { + _auxDict = Executor.getDict( + symbol.listAuxiliaryStates(), auxArrays) + } + _auxDict + } + + /** + * Copy parameters from arg_params, aux_params into executor's internal array. + * @param argParams : dict of name to NDArray of arguments + * @param auxParams : dict of name to NDArray of auxiliary states. + * @param allowExtraParams + * Whether allow extra parameters that are not needed by symbol + * If this is True, no error will be thrown when arg_params or aux_params + * contain extra parameters that is not needed by the executor. + * @throws IllegalArgumentException If there is additional parameters in the dict but allow_extra_params=False + */ + def copyParamsFrom(argParams: Map[String, NDArray], + auxParams: Map[String, NDArray], + allowExtraParams: Boolean = false): Unit = { + argParams.foreach { case (name, array) => + if (argDict.contains(name)) { + array.copyTo(argDict(name)) + } else { + require(allowExtraParams, s"Find name $name that is not in the arguments") + } + } + if (auxParams != null) { + auxParams.foreach { case (name, array) => + if (auxDict.contains(name)) { + array.copyTo(auxDict(name)) + } else { + require(allowExtraParams, s"Find name $name that is not in the auxiliary states") + } + } + } + } + + def copyParamsFrom(argParams: Map[String, NDArray], allowExtraParams: Boolean): Unit = { + copyParamsFrom(argParams, null, allowExtraParams) + } + + def copyParamsFrom(argParams: Map[String, NDArray]): Unit = { + copyParamsFrom(argParams, allowExtraParams = false) + } + + /** + * Get a debug string about internal execution plan. + * @return Debug string of the executor. + */ + def debugStr: String = { + val str = new RefString + checkCall(_LIB.mxExecutorPrint(handle, str)) + str.value } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 2267b53b141f..dea13f362cb5 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -87,4 +87,10 @@ class LibInfo { @native def mxExecutorOutputs(handle: ExecutorHandle, outputs: ArrayBuffer[NDArrayHandle]): Int @native def mxExecutorFree(handle: ExecutorHandle): Int @native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int + @native def mxExecutorBackward(handle: ExecutorHandle, + gradsSize: Int, + // outs ought to be Array[NDArrayHandle], + // we pass ptr address directly for performance consideration + grads: Array[CPtrAddress]): Int + @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 7835fb557063..3f70f2764c2a 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -1,5 +1,21 @@ package ml.dmlc.mxnet class Symbol { + /** + * List all the arguments in the symbol. + * @return Array of all the arguments. + */ + def listArguments(): Array[String] = ??? + /** + * List all auxiliary states in the symbool. + * @return The names of the auxiliary states. + * Notes + * ----- + * Auxiliary states are special states of symbols that do not corresponds to an argument, + * and do not have gradient. But still be useful for the specific operations. + * A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. + * Most operators do not have Auxiliary states. + */ + def listAuxiliaryStates(): Array[String] = ??? } diff --git a/scala-package/native/src/main/native/jni_helper_func.h b/scala-package/native/src/main/native/jni_helper_func.h index 0d651adf9a94..43668bc1df9d 100644 --- a/scala-package/native/src/main/native/jni_helper_func.h +++ b/scala-package/native/src/main/native/jni_helper_func.h @@ -21,4 +21,10 @@ void setIntField(JNIEnv *env, jobject obj, jint value) { env->SetIntField(obj, refFid, value); } +void setStringField(JNIEnv *env, jobject obj, const char *value) { + jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefString"); + jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;"); + env->SetObjectField(obj, refFid, env->NewStringUTF(value)); +} + #endif diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index c15cb3b5e009..a63ee281f35b 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -422,6 +422,26 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorForward return MXExecutorForward((ExecutorHandle)ptr, (int)isTrain); } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBackward + (JNIEnv * env, jobject obj, jobject handle, jint gradsSize, jlongArray grads) { + jlong executorPtr = getLongField(env, handle); + jlong *gradArr = env->GetLongArrayElements(grads, NULL); + int ret = MXExecutorBackward((ExecutorHandle)executorPtr, + (mx_uint)gradsSize, + (NDArrayHandle *)gradArr) + env->ReleaseLongArrayElements(grads, gradArr, 0); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorPrint + (JNIEnv * env, jobject obj, jobject handle, jobject debugStr) { + jlong ptr = getLongField(env, handle); + const char *retDebugStr; + int ret = MXExecutorPrint((ExecutorHandle)handle, &retDebugStr); + setStringField(env, debugStr, retDebugStr); + return ret; +} + JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) { char *tmpstr = "MXNetError"; jstring rtstr = env->NewStringUTF(tmpstr); @@ -433,4 +453,3 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env, jo puts("Free ndarray called"); return 0; } - From 781a0a72a7b5ad06d259de643b0a6198a676cceb Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 26 Dec 2015 16:19:54 +0800 Subject: [PATCH 05/13] add Executor::splitInputSlice, fix build break --- .../main/scala/ml/dmlc/mxnet/Executor.scala | 40 +++++++++++++++++++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 4 +- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index ae40e68c5d30..d8d88163abaa 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -10,6 +10,46 @@ object Executor { require(names.toSet.size == names.length, "Duplicate names detected") (names zip ndarrays).toMap } + + /** + * Get input slice from the input shape. + Parameters + ---------- + * batch_size : int + * The number of samples in a mini-batch. + * work_load_list : list of float or int, optional + * The list of work load for different devices, + * in the same order as ctx + * Returns + * ------- + * slices : list of slice + * The split slices to get a specific slice. + * Raises + * ------ + * ValueError + * If there are two many splits such that some slice can be empty. + */ + private def splitInputSlice[@specialized(Int, Float, Double) V] + (batchSize: Int, workLoadList: Array[V]) + (implicit num: Numeric[V]): Array[(Int, Int)] = { + val totalWorkLoad = workLoadList.sum.asInstanceOf[Float] + val batchNumList = workLoadList.map(workLoad => + math.round(workLoad.asInstanceOf[Float] * batchSize / totalWorkLoad)) + val batchNumSum = batchNumList.sum + if (batchNumSum < batchSize) { + batchNumList(batchNumList.length-1) += batchSize - batchNumSum + } + + val slices = ArrayBuffer.empty[(Int, Int)] + var end = 0 + batchNumList.foreach(batchNum => { + val begin = math.min(end, batchSize) + end = math.min(begin + batchNum, batchSize) + require(begin < end, "Too many slices such that some splits are empty") + slices.append((begin, end)) + }) + slices.toArray + } } /** diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index a63ee281f35b..8978fe0545bf 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -396,7 +396,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorOutputs // Base.ExecutorHandle.constructor jclass ndArrayClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jmethodID ndArrayConstructor = env->GetMethodID(fhClass,"","(J)V"); + jmethodID ndArrayConstructor = env->GetMethodID(ndArrayClass,"","(J)V"); // fill java outputs jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer"); @@ -428,7 +428,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBackward jlong *gradArr = env->GetLongArrayElements(grads, NULL); int ret = MXExecutorBackward((ExecutorHandle)executorPtr, (mx_uint)gradsSize, - (NDArrayHandle *)gradArr) + (NDArrayHandle *)gradArr); env->ReleaseLongArrayElements(grads, gradArr, 0); return ret; } From e720e70824bb54ad145ced4ef310fa5f8b73e62f Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 26 Dec 2015 23:14:13 +0800 Subject: [PATCH 06/13] Executor static functions --- .../main/scala/ml/dmlc/mxnet/Executor.scala | 56 +++++++++++++------ 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index d8d88163abaa..c6f9a5d15924 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -13,21 +13,10 @@ object Executor { /** * Get input slice from the input shape. - Parameters - ---------- - * batch_size : int - * The number of samples in a mini-batch. - * work_load_list : list of float or int, optional - * The list of work load for different devices, - * in the same order as ctx - * Returns - * ------- - * slices : list of slice - * The split slices to get a specific slice. - * Raises - * ------ - * ValueError - * If there are two many splits such that some slice can be empty. + * @param batchSize The number of samples in a mini-batch. + * @param workLoadList The list of work load for different devices, in the same order as ctx + * @return The split slices to get a specific slice. + * @throws IllegalArgumentException If there are two many splits such that some slice can be empty. */ private def splitInputSlice[@specialized(Int, Float, Double) V] (batchSize: Int, workLoadList: Array[V]) @@ -50,6 +39,40 @@ object Executor { }) slices.toArray } + + /** + * Check the argument names of symbol. + * This function checks the duplication of arguments in Symbol. + * The check is done for feedforward net for now. + * @param symbol The network configuration + */ + private def checkArguments(symbol: Symbol): Unit = { + val argNames = symbol.listArguments() + require(argNames.toSet.size == argNames.length, + "Find duplicated argument name," + + "please make the weight name non-duplicated(using name arguments)," + + s"arguments are $argNames") + + val auxNames = symbol.listAuxiliaryStates() + require(auxNames.toSet.size == auxNames.length, + "Find duplicated auxiliary param name," + + "please make the weight name non-duplicated(using name arguments)," + + s"arguments are $auxNames") + } + + // Load a list of arrays into a list of arrays + private def loadGeneral(data: Array[NDArray], targets: Array[NDArray]): Unit = { + (data zip targets).foreach { case (dSrc, dTarget) => + dSrc.copyTo(dTarget) + } + } + + // Load a list of arrays into a list of arrays specified by slices + private def loadGeneral(data: Array[NDArray], targets: Array[(Int, Int, NDArray)]): Unit = { + (data zip targets).foreach { case (dSrc, (start, end, dTarget)) => + dSrc.slice(start, end).copyTo(dTarget) + } + } } /** @@ -128,7 +151,8 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { /** * Install callback. - * callback Takes a string and an NDArrayHandle. + * TODO: make callback a java class to make it java-friendly + * @param callback Takes a string and an NDArrayHandle. */ def setMonitorCallback(callback: (String, NDArrayHandle) => Unit): Unit = ??? From 7f89e07466c0d3f419b4704212a465ae60c7827c Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 26 Dec 2015 23:38:06 +0800 Subject: [PATCH 07/13] merge Monitor and fix some type mismatch --- .../core/src/main/scala/ml/dmlc/mxnet/Monitor.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala index 1c9c1edbabc1..4879dee4f0d2 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala @@ -1,5 +1,6 @@ package ml.dmlc.mxnet +import ml.dmlc.mxnet.Base.NDArrayHandle import org.slf4j.LoggerFactory import scala.collection.mutable @@ -27,10 +28,11 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) => private var step: Int = 0 private var exes = new mutable.Queue[Executor] - protected val statHelper = (name: String, arr: NDArray) => { + protected def statHelper(name: String, arr: NDArrayHandle): Unit = { if (activated) { // TODO: more details here - queue ++= List((step, name, statFunc(arr))) + val array = new NDArray(arr, writable=false) + queue ++= List((step, name, statFunc(array))) } } From ed4077ddfab64f84a60af88877d492c85fb26edd Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 27 Dec 2015 15:47:44 +0800 Subject: [PATCH 08/13] change Monitor callback function to trait to make it java-friendly --- .../src/main/scala/ml/dmlc/mxnet/Executor.scala | 2 +- .../src/main/scala/ml/dmlc/mxnet/Monitor.scala | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index c6f9a5d15924..239cec4bbb20 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -154,7 +154,7 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { * TODO: make callback a java class to make it java-friendly * @param callback Takes a string and an NDArrayHandle. */ - def setMonitorCallback(callback: (String, NDArrayHandle) => Unit): Unit = ??? + def setMonitorCallback(callback: MXMonitorCallback): Unit = ??? /** * Get dictionary representation of argument arrrays. diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala index 4879dee4f0d2..cb663f6a11f0 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala @@ -28,11 +28,13 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) => private var step: Int = 0 private var exes = new mutable.Queue[Executor] - protected def statHelper(name: String, arr: NDArrayHandle): Unit = { - if (activated) { - // TODO: more details here - val array = new NDArray(arr, writable=false) - queue ++= List((step, name, statFunc(array))) + val statHelper: MXMonitorCallback = new MXMonitorCallback { + override def invoke(name: String, arr: NDArrayHandle): Unit = { + if (activated) { + // TODO: more details here + val array = new NDArray(arr, writable=false) + queue ++= List((step, name, statFunc(array))) + } } } @@ -114,3 +116,7 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) => } } + +trait MXMonitorCallback { + def invoke(name: String, arr: NDArrayHandle): Unit +} From 7cc514a9374fe1b089fcd1c3ef53122b098e7c68 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 27 Dec 2015 18:00:12 +0800 Subject: [PATCH 09/13] complete Monitor.scala, add NDArray::norm --- .../main/scala/ml/dmlc/mxnet/Executor.scala | 21 +++-- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 1 + .../main/scala/ml/dmlc/mxnet/Monitor.scala | 79 +++++++++---------- .../main/scala/ml/dmlc/mxnet/NDArray.scala | 9 +++ .../scala/ml/dmlc/mxnet/NDArraySuite.scala | 11 ++- .../main/native/ml_dmlc_mxnet_native_c_api.cc | 36 +++++++++ 6 files changed, 102 insertions(+), 55 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index 239cec4bbb20..cafbae36362c 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -86,15 +86,12 @@ object Executor { */ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { var argArrays: Array[NDArray] = null - var gradArrays: Array[NDArray] = null - var auxArrays: Array[NDArray] = null - var outputs: Array[NDArray] = getOutputs - var _argDict: Map[String, NDArray] = null - var _auxDict: Map[String, NDArray] = null - /* - self._grad_dict = None - self._monitor_callback = None - */ + protected var gradArrays: Array[NDArray] = null + protected var auxArrays: Array[NDArray] = null + protected var outputs: Array[NDArray] = getOutputs + protected var _argDict: Map[String, NDArray] = null + protected var _auxDict: Map[String, NDArray] = null + protected var monitorCallback: MXMonitorCallback = null override def finalize(): Unit = { checkCall(_LIB.mxExecutorFree(handle)) @@ -151,10 +148,12 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { /** * Install callback. - * TODO: make callback a java class to make it java-friendly * @param callback Takes a string and an NDArrayHandle. */ - def setMonitorCallback(callback: MXMonitorCallback): Unit = ??? + def setMonitorCallback(callback: MXMonitorCallback): Unit = { + monitorCallback = callback + checkCall(_LIB.mxExecutorSetMonitorCallback(handle, monitorCallback)) + } /** * Get dictionary representation of argument arrrays. diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index dea13f362cb5..7706312bfd68 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -93,4 +93,5 @@ class LibInfo { // we pass ptr address directly for performance consideration grads: Array[CPtrAddress]): Int @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int + @native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala index cb663f6a11f0..5046c8bbd7d2 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala @@ -2,12 +2,13 @@ package ml.dmlc.mxnet import ml.dmlc.mxnet.Base.NDArrayHandle import org.slf4j.LoggerFactory + import scala.collection.mutable /** * Monitor outputs, weights, and gradients for debugging. * - * @author Yuan Tang + * @author Yuan Tang, Yizhi Liu * * @param interval Number of batches between printing. * @param statFunc A function that computes statistics of tensors. @@ -19,8 +20,9 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) => private val logger = LoggerFactory.getLogger(classOf[Monitor]) if (statFunc == null) { - // TODO: more details here - statFunc = (x: NDArray) => x + statFunc = (x: NDArray) => { + NDArray.norm(x) / math.sqrt(x.size.toDouble).toFloat + } } private var activated: Boolean = false @@ -30,88 +32,79 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) => val statHelper: MXMonitorCallback = new MXMonitorCallback { override def invoke(name: String, arr: NDArrayHandle): Unit = { + // wrapper for executor callback if (activated) { - // TODO: more details here val array = new NDArray(arr, writable=false) - queue ++= List((step, name, statFunc(array))) + val elem = (step, name, statFunc(array)) + queue += elem } } } - /** * Install callback to executor. * Supports installing to multiple exes * @param exe the Executor (returned by symbol.bind) to install to. */ - def install(exe: Executor) = { + def install(exe: Executor): Unit = { exe.setMonitorCallback(statHelper) - exes ++= List(exe) + exes += exe } - /** * Start collecting stats for current batch. * Call before forward */ - def tic = { + def tic(): Unit = { if (step % interval == 0) { exes.foreach { exe => - exe.argArrays.foreach {arr => arr.waitToRead()} + exe.argArrays.foreach(_.waitToRead()) } - queue = new mutable.Queue[(Int, String, NDArray)] + queue = new mutable.Queue[(Int, String, NDArray)] activated = true } step += 1 } - /** * End collecting for current batch and return results. * Call after computation of current batch. */ - def toc: mutable.Queue[(Int, String, String)] = { - + def toc(): mutable.Queue[(Int, String, String)] = { if (activated) { exes.foreach { exe => - exe.argArrays.foreach {arr => arr.waitToRead()} + exe.argArrays.foreach(_.waitToRead()) } exes.foreach { exe => - null - // TODO: need to implement Symbol first - /* for name, array in zip(exe._symbol.list_arguments(), exe.arg_arrays): - self.queue.append((self.step, name, self.stat_func(array)))*/ + (exe.symbol.listArguments() zip exe.argArrays).foreach { case (name, array) => + val elem = (step, name, statFunc(array)) + queue += elem + } } - } else { - return new mutable.Queue[(Int, String, String)] - } - - activated = false - - val res = new mutable.Queue[(Int, String, String)] - - queue.foreach { q => - val (n, k, v) = q - if (v.shape.sameElements(Array(1))) { - res ++= List((n, k, v.toScalar.toString)) - } else { - res ++= List((n, k, v.toArray.toString)) + activated = false + val res = new mutable.Queue[(Int, String, String)] + queue.foreach { q => + val (n, k, v) = q + if (v.shape.sameElements(Array(1))) { + res += ((n, k, v.toScalar.toString)) + } else { + res += ((n, k, s"[${v.toArray.mkString(",")}]")) + } } + queue = new mutable.Queue[(Int, String, NDArray)] + res + } else { + new mutable.Queue[(Int, String, String)] } - - queue = new mutable.Queue[(Int, String, NDArray)] - - return res } /** * End collecting and print results */ - def tocPrint = { - val res = toc - res.foreach { re => - val (n, k, v) = re - logger.info(s"Batch: ${n} ${k} ${v}") + def tocPrint(): Unit = { + val res = toc() + res.foreach { case (n, k, v) => + logger.info(s"Batch: $n $k $v") } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 91a8e526634e..53b2145bc29a 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -259,6 +259,15 @@ object NDArray { NDArray._unaryNDArrayFunction("sqrt", src) } + /** + * Take L2 norm of the src. + * @param src Source input to the function + * @return new [[NDArray]] of shape (1,) on the same device + */ + def norm(src: NDArray): NDArray = { + NDArray._unaryNDArrayFunction("norm", src) + } + /** * Create a new NDArray that copies content from source_array. * @param sourceArr Source data to create NDArray from. diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala index dba751df7916..5fa02ece10bf 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala @@ -1,7 +1,8 @@ package ml.dmlc.mxnet -import org.scalatest.{FunSuite, BeforeAndAfterAll} import ml.dmlc.mxnet.NDArrayConversions._ +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalactic.Tolerance._ class NDArraySuite extends FunSuite with BeforeAndAfterAll { test("to java array") { @@ -96,4 +97,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { ndarray.set(Array(0f, 1f, 4f, 9f)) assert(NDArray.sqrt(ndarray).toArray === Array(0f, 1f, 2f, 3f)) } + + test("norm") { + val ndarray = NDArray.empty(3, 1) + ndarray.set(Array(1f, 2f, 3f)) + val normed = NDArray.norm(ndarray) + assert(normed.shape === Array(1)) + assert(normed.toScalar === math.sqrt(14.0).toFloat +- 1e-3f) + } } diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 8978fe0545bf..c387f24e1350 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "jni_helper_func.h" @@ -442,6 +443,41 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorPrint return ret; } +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorSetMonitorCallback + (JNIEnv *env, jobject obj, jobject handle, jobject callbackFuncObj) { + jlong executorPtr = getLongField(env, handle); + jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj); + std::function callback + = [env, callbackFuncObjGlb](const char *name, NDArrayHandle array) { + // find java callback method + jclass callbackClass = env->GetObjectClass(callbackFuncObjGlb); + jmethodID invokeFunc = env->GetMethodID(callbackClass, + "invoke", "(Ljava/lang/String;Lml/dmlc/mxnet/Base$RefLong;)V"); + + jstring jname = env->NewStringUTF(name); + // ndArray handle + jclass ndHandleClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jmethodID ndHandleCont = env->GetMethodID(ndHandleClass,"","(J)V"); + jobject jNDArrayHandle = env->NewObject(ndHandleClass, ndHandleCont, (long)array); + + env->CallVoidMethod(callbackFuncObjGlb, invokeFunc, jname, jNDArrayHandle); + env->DeleteGlobalRef(callbackFuncObjGlb); + }; + /* TODO: we need to modify Executor::SetMonitorCallback, make it take std::function as param + try { + mxnet::Executor *exec = static_cast((ExecutorHandle)executorPtr); + exec->SetMonitorCallback(callback); + } catch(dmlc::Error &except) { + // It'll be too complicated to set & get mx error in jni code. + // thus simply return -1 to indicate a failure. + // Notice that we'll NOT be able to run MXGetLastError + // to get the error message after this function fails. + return -1; + } + */ + return 0; +} + JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) { char *tmpstr = "MXNetError"; jstring rtstr = env->NewStringUTF(tmpstr); From 52733fcafa3de3d1e57613f7ea618b5c76c331fe Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 27 Dec 2015 18:07:02 +0800 Subject: [PATCH 10/13] typo fix --- scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 7706312bfd68..715d39fc67a8 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -89,7 +89,7 @@ class LibInfo { @native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int @native def mxExecutorBackward(handle: ExecutorHandle, gradsSize: Int, - // outs ought to be Array[NDArrayHandle], + // grads ought to be Array[NDArrayHandle], // we pass ptr address directly for performance consideration grads: Array[CPtrAddress]): Int @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int From a03c22803884ab3bbdd506c3c0d9187849388995 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sun, 27 Dec 2015 18:29:14 +0800 Subject: [PATCH 11/13] link cblas for linux profile --- scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala | 2 +- scala-package/native/linux-x86_64/pom.xml | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala index 5046c8bbd7d2..bc48a0514af1 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala @@ -26,7 +26,7 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) => } private var activated: Boolean = false - private var queue = new mutable.Queue[(Int, String, NDArray)] + private var queue = new mutable.Queue[(Int, String, NDArray)] private var step: Int = 0 private var exes = new mutable.Queue[Executor] diff --git a/scala-package/native/linux-x86_64/pom.xml b/scala-package/native/linux-x86_64/pom.xml index ef1fd3c0f67b..da7bc5ffb969 100644 --- a/scala-package/native/linux-x86_64/pom.xml +++ b/scala-package/native/linux-x86_64/pom.xml @@ -112,9 +112,7 @@ -shared - - -fopenmp - + -fopenmp ${ldflags.blas} -Wl,--whole-archive ../../../lib/libmxnet.a -Wl,-no-whole-archive From 356ff475e01fcddf81608c95b6a5e16b9b8f2e41 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 28 Dec 2015 01:16:19 +0800 Subject: [PATCH 12/13] force load libblas in travis --- scala-package/native/linux-x86_64/pom.xml | 9 ++++++++- tests/travis/run_test.sh | 10 +++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/scala-package/native/linux-x86_64/pom.xml b/scala-package/native/linux-x86_64/pom.xml index da7bc5ffb969..398cf01219bd 100644 --- a/scala-package/native/linux-x86_64/pom.xml +++ b/scala-package/native/linux-x86_64/pom.xml @@ -56,7 +56,14 @@ - + + + + + + + + diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index f0d667ab213a..acdf2e479236 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -131,8 +131,14 @@ if [ ${TASK} == "scala_test" ]; then mvn integration-test -P osx-x86_64 --log-file scala_test_results.txt fi if [ ${TRAVIS_OS_NAME} == "linux" ]; then + # (Yizhi Liu) I'm not sure it is a proper solution, + # which is mentioned here: + # http://stackoverflow.com/questions/9558909/jni-symbol-lookup-error-in-shared-library-on-linux/13086028#13086028 + # I really don't know why we have to export LD_PRELOAD + # to make libblas loaded in travis. It just works. + export LD_PRELOAD=/usr/lib/libblas/libblas.so # use g++-4.8 for linux - mvn clean package -P linux-x86_64 -D cxx=g++-4.8 + mvn clean package -P linux-x86_64 -D cxx=g++-4.8 -D ldflags.blas=-lblas mvn integration-test -P linux-x86_64 --log-file scala_test_results.txt fi @@ -141,5 +147,3 @@ if [ ${TASK} == "scala_test" ]; then exit 0 fi - - From 213cd1084ee0f45cbc5f8129120dd12850a581a6 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 28 Dec 2015 01:29:22 +0800 Subject: [PATCH 13/13] linux travis scala-test should be success this time, FSM bless me --- tests/travis/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index acdf2e479236..f3919a06fa20 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -139,7 +139,7 @@ if [ ${TASK} == "scala_test" ]; then export LD_PRELOAD=/usr/lib/libblas/libblas.so # use g++-4.8 for linux mvn clean package -P linux-x86_64 -D cxx=g++-4.8 -D ldflags.blas=-lblas - mvn integration-test -P linux-x86_64 --log-file scala_test_results.txt + mvn integration-test -P linux-x86_64 -D cxx=g++-4.8 -D ldflags.blas=-lblas --log-file scala_test_results.txt fi chmod +x ../tests/travis/error_detector.sh