Skip to content

Commit

Permalink
KVStore setOptimizer, with JavaSerializer
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Dec 19, 2015
1 parent ea1414f commit ebba851
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 31 deletions.
41 changes: 41 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,32 @@ class KVStore(private val handle: KVStoreHandle) {
pull(Array(key), Array(out), priority)
}

// Get the type of this kvstore
def getType: String = {
val kvType = new RefString
checkCall(_LIB.mxKVStoreGetType(handle, kvType))
kvType.value
}

/**
* Register an optimizer to the store
* If there are multiple machines, this process (should be a worker node)
* will pack this optimizer and send it to all servers. It returns after
* this action is done.
*
* @param optimizer the optimizer
*/
def setOptimizer(optimizer: Optimizer): Unit = {
val isWorker = new RefInt
checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker))
if ("dist" == getType && isWorker.value != 0) {
val optSerialized = Serializer.getSerializer.serialize(optimizer)
_sendCommandToServers(0, Serializer.encodeBase64String(optSerialized))
} else {
setUpdater(Optimizer.getUpdater(optimizer))
}
}

/**
* Set a push updater into the store.
*
Expand All @@ -110,4 +136,19 @@ class KVStore(private val handle: KVStoreHandle) {
this.updaterFunc = updater
checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc, null))
}

/**
* Send a command to all server nodes
*
* Send a command to all server nodes, which will make each server node run
* KVStoreServer.controller
*
* This function returns after the command has been executed in all server nodes
*
* @param head the head of the command
* @param body the body of the command
*/
private def _sendCommandToServers(head: Int, body: String): Unit = {
checkCall(_LIB.mxKVStoreSendCommmandToServers(handle, head, body))
}
}
4 changes: 4 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,8 @@ class LibInfo {
@native def mxKVStoreSetUpdater(handle: KVStoreHandle,
updaterFunc: MXKVStoreUpdater,
updaterHandle: AnyRef): Int
@native def mxKVStoreIsWorkerNode(isWorker: RefInt): Int
@native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int
@native def mxKVStoreSendCommmandToServers(handle: KVStoreHandle,
head: Int, body: String): Int
}
27 changes: 25 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
package ml.dmlc.mxnet

class Optimizer {
object Optimizer {
def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
new MXKVStoreUpdater {
val states = new scala.collection.mutable.HashMap[Int, AnyRef]
override def update(index: Int, grad: NDArray, weight: NDArray, handle: AnyRef): Unit = {
val state = states.getOrElseUpdate(index, optimizer.createState(index, weight))
optimizer.update(index, weight, grad, state)
}
}
}
}

abstract class Optimizer extends Serializable {
/**
* Update the parameters.
* @param index An unique integer key used to index the parameters
* @param weight weight ndarray
* @param grad grad ndarray
* @param state NDArray or other objects returned by initState
* The auxiliary state used in optimization.
*/
def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = ???

// Create additional optimizer state such as momentum.
def createState(index: Int, weight: NDArray): AnyRef
}

trait MXKVStoreUpdater {
Expand All @@ -13,5 +36,5 @@ trait MXKVStoreUpdater {
* @param local the value stored on local on this key
* @param handle The additional handle to the updater
*/
def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef): Unit
def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef = null): Unit
}
58 changes: 58 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package ml.dmlc.mxnet

import java.io._
import java.nio.ByteBuffer
import java.nio.charset.Charset
import java.util.Base64

import scala.reflect.ClassTag

/**
* Serialize & deserialize Java/Scala [[Serializable]] objects
* @author Yizhi Liu
*/
abstract class Serializer {
def serialize[T: ClassTag](t: T): ByteBuffer
def deserialize[T: ClassTag](bytes: ByteBuffer): T
}

object Serializer {
val UTF8 = Charset.forName("UTF-8")

def getSerializer: Serializer = getSerializer(None)

def getSerializer(serializer: Serializer): Serializer = {
// TODO: dynamically get from mxnet env to support other serializers like Kyro
if (serializer == null) new JavaSerializer else serializer
}

def getSerializer(serializer: Option[Serializer]): Serializer = {
// TODO: dynamically get from mxnet env to support other serializers like Kyro
serializer.getOrElse(new JavaSerializer)
}

def encodeBase64String(bytes: ByteBuffer): String = {
new String(Base64.getEncoder.encode(bytes).array, UTF8)
}

def decodeBase64String(str: String): ByteBuffer = {
ByteBuffer.wrap(Base64.getDecoder.decode(str.getBytes(UTF8)))
}
}

class JavaSerializer extends Serializer {
override def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = new ObjectOutputStream(bos)
out.writeObject(t)
out.close()
ByteBuffer.wrap(bos.toByteArray)
}

override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
val byteArray = bytes.array()
val bis = new ByteArrayInputStream(byteArray)
val in = new ObjectInputStream(bis)
in.readObject().asInstanceOf[T]
}
}
18 changes: 18 additions & 0 deletions scala-package/native/src/main/native/jni_helper_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include <jni.h>

#ifndef MXNET_SCALA_JNI_HELPER_FUNC_H
#define MXNET_SCALA_JNI_HELPER_FUNC_H

jlong getLongField(JNIEnv *env, jobject obj) {
jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refFid = env->GetFieldID(refClass, "value", "J");
return env->GetLongField(obj, refFid);
}

jint getIntField(JNIEnv *env, jobject obj) {
jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
return env->GetIntField(obj, refFid);
}

#endif
59 changes: 30 additions & 29 deletions scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include <iostream>
#include <cstdlib>
#include <functional>
#include <jni.h>
#include <mxnet/ndarray.h>
#include <mxnet/kvstore.h>
#include <mxnet/c_api.h>

#include "jni_helper_func.h"
#include "ml_dmlc_mxnet_native_c_api.h" // generated by javah

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayCreateNone(JNIEnv *env, jobject obj, jobject ndArrayHandle) {
Expand Down Expand Up @@ -66,9 +66,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncDescribe(JNIEnv *env, jo
jobject nScalars,
jobject nMutateVars,
jobject typeMask) {
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
jlong funcPtr = env->GetLongField(funcHandle, refLongFid);
jlong funcPtr = getLongField(env, funcHandle);

mx_uint numUseVars;
mx_uint numScalars;
Expand All @@ -95,9 +93,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncGetInfo(JNIEnv *env, job
jobject argNames,
jobject argTypes,
jobject argDescs) {
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
jlong funcPtr = env->GetLongField(funcHandle, refLongFid);
jlong funcPtr = getLongField(env, funcHandle);

const char *cName;
const char *cDesc;
Expand Down Expand Up @@ -137,9 +133,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxFuncInvoke(JNIEnv *env, jobj
jlongArray useVars,
jfloatArray scalarArgs,
jlongArray mutateVars) {
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
jlong funcPtr = env->GetLongField(funcHandle, refLongFid);
jlong funcPtr = getLongField(env, funcHandle);

jlong *cUseVars = env->GetLongArrayElements(useVars, NULL);
jfloat *cScalarArgs = env->GetFloatArrayElements(scalarArgs, NULL);
Expand All @@ -158,9 +152,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayGetShape(JNIEnv *env,
jobject ndArrayHandle,
jobject ndimRef,
jobject dataBuf) {
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
jlong ndArrayPtr = env->GetLongField(ndArrayHandle, refLongFid);
jlong ndArrayPtr = getLongField(env, ndArrayHandle);

mx_uint ndim;
const mx_uint *pdata;
Expand Down Expand Up @@ -190,9 +182,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArraySyncCopyToCPU(JNIEnv
jobject ndArrayHandle,
jfloatArray data,
jint size) {
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
jlong ndArrayPtr = env->GetLongField(ndArrayHandle, refLongFid);
jlong ndArrayPtr = getLongField(env, ndArrayHandle);

jfloat *pdata = env->GetFloatArrayElements(data, NULL);
int ret = MXNDArraySyncCopyToCPU((NDArrayHandle)ndArrayPtr, (mx_float *)pdata, size);
Expand Down Expand Up @@ -224,9 +214,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv *en
jobject updaterFuncObj,
jobject updaterHandle) {
// get kv store ptr
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid);
jlong kvStorePtr = getLongField(env, kvStoreHandle);

jobject updaterFuncObjGlb = env->NewGlobalRef(updaterFuncObj);
jobject updaterHandleGlb = env->NewGlobalRef(updaterHandle);
Expand Down Expand Up @@ -290,9 +278,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit(JNIEnv *env, job
jintArray keys,
jlongArray values) {
// get kv store ptr
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid);
jlong kvStorePtr = getLongField(env, kvStoreHandle);

jint *keyArray = env->GetIntArrayElements(keys, NULL);
jlong *valueArray = env->GetLongArrayElements(values, NULL);
Expand All @@ -312,9 +298,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush(JNIEnv *env, job
jlongArray values,
jint priority) {
// get kv store ptr
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid);
jlong kvStorePtr = getLongField(env, kvStoreHandle);

jint *keyArray = env->GetIntArrayElements(keys, NULL);
jlong *valueArray = env->GetLongArrayElements(values, NULL);
Expand All @@ -335,9 +319,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, job
jlongArray outs,
jint priority) {
// get kv store ptr
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");
jlong kvStorePtr = env->GetLongField(kvStoreHandle, refLongFid);
jlong kvStorePtr = getLongField(env, kvStoreHandle);

jint *keyArray = env->GetIntArrayElements(keys, NULL);
jlong *outArray = env->GetLongArrayElements(outs, NULL);
Expand All @@ -351,7 +333,26 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, job
return ret;
}

// TODO: move to c_api_error.c
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetType
(JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject kvType) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
const char *type;
int ret = MXKVStoreGetType((KVStoreHandle)kvStorePtr, &type);
jclass refStringClass = env->FindClass("ml/dmlc/mxnet/Base$RefString");
jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;");
env->SetObjectField(kvType, valueStr, env->NewStringUTF(type));
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSendCommmandToServers
(JNIEnv *env, jobject obj, jobject kvStoreHandle, jint head, jstring body) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
const char *bodyCStr = env->GetStringUTFChars(body, 0);
int ret = MXKVStoreSendCommmandToServers((KVStoreHandle)kvStorePtr, head, bodyCStr);
env->ReleaseStringUTFChars(body, bodyCStr);
return ret;
}

JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) {
char *tmpstr = "MXNetError";
jstring rtstr = env->NewStringUTF(tmpstr);
Expand Down

0 comments on commit ebba851

Please sign in to comment.