diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala index f588aa20532b..cdb62e6d18a1 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala @@ -98,6 +98,32 @@ class KVStore(private val handle: KVStoreHandle) { pull(Array(key), Array(out), priority) } + // Get the type of this kvstore + def getType: String = { + val kvType = new RefString + checkCall(_LIB.mxKVStoreGetType(handle, kvType)) + kvType.value + } + + /** + * Register an optimizer to the store + * If there are multiple machines, this process (should be a worker node) + * will pack this optimizer and send it to all servers. It returns after + * this action is done. + * + * @param optimizer the optimizer + */ + def setOptimizer(optimizer: Optimizer): Unit = { + val isWorker = new RefInt + checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker)) + if ("dist" == getType && isWorker.value != 0) { + val optSerialized = Serializer.getSerializer.serialize(optimizer) + _sendCommandToServers(0, Serializer.encodeBase64String(optSerialized)) + } else { + setUpdater(Optimizer.getUpdater(optimizer)) + } + } + /** * Set a push updater into the store. * @@ -110,4 +136,19 @@ class KVStore(private val handle: KVStoreHandle) { this.updaterFunc = updater checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc, null)) } + + /** + * Send a command to all server nodes + * + * Send a command to all server nodes, which will make each server node run + * KVStoreServer.controller + * + * This function returns after the command has been executed in all server nodes + * + * @param head the head of the command + * @param body the body of the command + */ + private def _sendCommandToServers(head: Int, body: String): Unit = { + checkCall(_LIB.mxKVStoreSendCommmandToServers(handle, head, body)) + } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 18810b5440a5..c32877c571e2 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -71,4 +71,8 @@ class LibInfo { @native def mxKVStoreSetUpdater(handle: KVStoreHandle, updaterFunc: MXKVStoreUpdater, updaterHandle: AnyRef): Int + @native def mxKVStoreIsWorkerNode(isWorker: RefInt): Int + @native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int + @native def mxKVStoreSendCommmandToServers(handle: KVStoreHandle, + head: Int, body: String): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala index 7e23639bdad1..14d5dbd1e117 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala @@ -1,7 +1,30 @@ package ml.dmlc.mxnet -class Optimizer { +object Optimizer { + def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = { + new MXKVStoreUpdater { + val states = new scala.collection.mutable.HashMap[Int, AnyRef] + override def update(index: Int, grad: NDArray, weight: NDArray, handle: AnyRef): Unit = { + val state = states.getOrElseUpdate(index, optimizer.createState(index, weight)) + optimizer.update(index, weight, grad, state) + } + } + } +} + +abstract class Optimizer extends Serializable { + /** + * Update the parameters. + * @param index An unique integer key used to index the parameters + * @param weight weight ndarray + * @param grad grad ndarray + * @param state NDArray or other objects returned by initState + * The auxiliary state used in optimization. + */ + def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = ??? + // Create additional optimizer state such as momentum. + def createState(index: Int, weight: NDArray): AnyRef } trait MXKVStoreUpdater { @@ -13,5 +36,5 @@ trait MXKVStoreUpdater { * @param local the value stored on local on this key * @param handle The additional handle to the updater */ - def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef): Unit + def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef = null): Unit } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala new file mode 100644 index 000000000000..21d2410911a6 --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala @@ -0,0 +1,58 @@ +package ml.dmlc.mxnet + +import java.io._ +import java.nio.ByteBuffer +import java.nio.charset.Charset +import java.util.Base64 + +import scala.reflect.ClassTag + +/** + * Serialize & deserialize Java/Scala [[Serializable]] objects + * @author Yizhi Liu + */ +abstract class Serializer { + def serialize[T: ClassTag](t: T): ByteBuffer + def deserialize[T: ClassTag](bytes: ByteBuffer): T +} + +object Serializer { + val UTF8 = Charset.forName("UTF-8") + + def getSerializer: Serializer = getSerializer(None) + + def getSerializer(serializer: Serializer): Serializer = { + // TODO: dynamically get from mxnet env to support other serializers like Kyro + if (serializer == null) new JavaSerializer else serializer + } + + def getSerializer(serializer: Option[Serializer]): Serializer = { + // TODO: dynamically get from mxnet env to support other serializers like Kyro + serializer.getOrElse(new JavaSerializer) + } + + def encodeBase64String(bytes: ByteBuffer): String = { + new String(Base64.getEncoder.encode(bytes).array, UTF8) + } + + def decodeBase64String(str: String): ByteBuffer = { + ByteBuffer.wrap(Base64.getDecoder.decode(str.getBytes(UTF8))) + } +} + +class JavaSerializer extends Serializer { + override def serialize[T: ClassTag](t: T): ByteBuffer = { + val bos = new ByteArrayOutputStream() + val out = new ObjectOutputStream(bos) + out.writeObject(t) + out.close() + ByteBuffer.wrap(bos.toByteArray) + } + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + val byteArray = bytes.array() + val bis = new ByteArrayInputStream(byteArray) + val in = new ObjectInputStream(bis) + in.readObject().asInstanceOf[T] + } +} diff --git a/scala-package/native/src/main/native/jni_helper_func.h b/scala-package/native/src/main/native/jni_helper_func.h new file mode 100644 index 000000000000..864733d2be60 --- /dev/null +++ b/scala-package/native/src/main/native/jni_helper_func.h @@ -0,0 +1,18 @@ +#include + +#ifndef MXNET_SCALA_JNI_HELPER_FUNC_H +#define MXNET_SCALA_JNI_HELPER_FUNC_H + +jlong getLongField(JNIEnv *env, jobject obj) { + jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jfieldID refFid = env->GetFieldID(refClass, "value", "J"); + return env->GetLongField(obj, refFid); +} + +jint getIntField(JNIEnv *env, jobject obj) { + jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); + jfieldID refFid = env->GetFieldID(refClass, "value", "I"); + return env->GetIntField(obj, refFid); +} + +#endif diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index f711dff91dcd..3883ab84a562 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -1,11 +1,11 @@ #include #include #include -#include #include #include #include +#include "jni_helper_func.h" #include "ml_dmlc_mxnet_native_c_api.h" // generated by javah JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreateNone(JNIEnv *env, jobject obj, jobject ndArrayHandle) { @@ -66,9 +66,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncDescribe(JNIEnv *env, jo jobject nScalars, jobject nMutateVars, jobject typeMask) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong funcPtr = env->GetLongField(funcHandle, refLongFid); + jlong funcPtr = getLongField(env, funcHandle); mx_uint numUseVars; mx_uint numScalars; @@ -95,9 +93,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncGetInfo(JNIEnv *env, job jobject argNames, jobject argTypes, jobject argDescs) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong funcPtr = env->GetLongField(funcHandle, refLongFid); + jlong funcPtr = getLongField(env, funcHandle); const char *cName; const char *cDesc; @@ -137,9 +133,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncInvoke(JNIEnv *env, jobj jlongArray useVars, jfloatArray scalarArgs, jlongArray mutateVars) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong funcPtr = env->GetLongField(funcHandle, refLongFid); + jlong funcPtr = getLongField(env, funcHandle); jlong *cUseVars = env->GetLongArrayElements(useVars, NULL); jfloat *cScalarArgs = env->GetFloatArrayElements(scalarArgs, NULL); @@ -158,9 +152,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetShape(JNIEnv *env, jobject ndArrayHandle, jobject ndimRef, jobject dataBuf) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong ndArrayPtr = env->GetLongField(ndArrayHandle, refLongFid); + jlong ndArrayPtr = getLongField(env, ndArrayHandle); mx_uint ndim; const mx_uint *pdata; @@ -190,9 +182,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyToCPU(JNIEnv jobject ndArrayHandle, jfloatArray data, jint size) { - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong ndArrayPtr = env->GetLongField(ndArrayHandle, refLongFid); + jlong ndArrayPtr = getLongField(env, ndArrayHandle); jfloat *pdata = env->GetFloatArrayElements(data, NULL); int ret = MXNDArraySyncCopyToCPU((NDArrayHandle)ndArrayPtr, (mx_float *)pdata, size); @@ -224,9 +214,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv *en jobject updaterFuncObj, jobject updaterHandle) { // get kv store ptr - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid); + jlong kvStorePtr = getLongField(env, kvStoreHandle); jobject updaterFuncObjGlb = env->NewGlobalRef(updaterFuncObj); jobject updaterHandleGlb = env->NewGlobalRef(updaterHandle); @@ -290,9 +278,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit(JNIEnv *env, job jintArray keys, jlongArray values) { // get kv store ptr - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid); + jlong kvStorePtr = getLongField(env, kvStoreHandle); jint *keyArray = env->GetIntArrayElements(keys, NULL); jlong *valueArray = env->GetLongArrayElements(values, NULL); @@ -312,9 +298,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush(JNIEnv *env, job jlongArray values, jint priority) { // get kv store ptr - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid); + jlong kvStorePtr = getLongField(env, kvStoreHandle); jint *keyArray = env->GetIntArrayElements(keys, NULL); jlong *valueArray = env->GetLongArrayElements(values, NULL); @@ -335,9 +319,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, job jlongArray outs, jint priority) { // get kv store ptr - jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); - jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); - jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid); + jlong kvStorePtr = getLongField(env, kvStoreHandle); jint *keyArray = env->GetIntArrayElements(keys, NULL); jlong *outArray = env->GetLongArrayElements(outs, NULL); @@ -351,7 +333,26 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, job return ret; } -// TODO: move to c_api_error.c +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetType + (JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject kvType) { + jlong kvStorePtr = getLongField(env, kvStoreHandle); + const char *type; + int ret = MXKVStoreGetType((KVStoreHandle)kvStorePtr, &type); + jclass refStringClass = env->FindClass("ml/dmlc/mxnet/Base$RefString"); + jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;"); + env->SetObjectField(kvType, valueStr, env->NewStringUTF(type)); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSendCommmandToServers + (JNIEnv *env, jobject obj, jobject kvStoreHandle, jint head, jstring body) { + jlong kvStorePtr = getLongField(env, kvStoreHandle); + const char *bodyCStr = env->GetStringUTFChars(body, 0); + int ret = MXKVStoreSendCommmandToServers((KVStoreHandle)kvStorePtr, head, bodyCStr); + env->ReleaseStringUTFChars(body, bodyCStr); + return ret; +} + JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) { char *tmpstr = "MXNetError"; jstring rtstr = env->NewStringUTF(tmpstr);