Skip to content

Commit

Permalink
Merge pull request apache#21 from javelinjs/scala-package-l
Browse files Browse the repository at this point in the history
Change Executor & KVStore handle to pure long
  • Loading branch information
terrytangyuan committed Jan 10, 2016
2 parents 18e68fe + 63cd3c6 commit b63cf23
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 71 deletions.
8 changes: 5 additions & 3 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@ object Base {
type MXFloat = Float
type CPtrAddress = Long

type SymbolHandle = CPtrAddress
type NDArrayHandle = CPtrAddress
type FunctionHandle = CPtrAddress
type DataIterHandle = CPtrAddress
type DataIterCreator = CPtrAddress
type KVStoreHandle = CPtrAddress
type ExecutorHandle = CPtrAddress
type SymbolHandle = CPtrAddress

type MXUintRef = RefInt
type MXFloatRef = RefFloat
type NDArrayHandleRef = RefLong
type FunctionHandleRef = RefLong
type DataIterHandleRef = RefLong
type DataIterCreatorRef = RefLong
type KVStoreHandle = RefLong
type ExecutorHandle = RefLong
type KVStoreHandleRef = RefLong
type ExecutorHandleRef = RefLong
type SymbolHandleRef = RefLong

System.loadLibrary("mxnet-scala")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ object Executor {
* @see Symbol.bind : to create executor
*/
// scalastyle:off finalize
class Executor(val handle: ExecutorHandle, val symbol: Symbol) {
class Executor(private[mxnet] val handle: ExecutorHandle, private[mxnet] val symbol: Symbol) {
var argArrays: Array[NDArray] = null
protected var gradArrays: Array[NDArray] = null
protected var auxArrays: Array[NDArray] = null
Expand Down Expand Up @@ -136,7 +136,7 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) {
def backward(outGrads: Array[NDArray]): Unit = {
require(outGrads != null)
val ndArrayPtrs = outGrads.map(_.handle)
checkCall(_LIB.mxExecutorBackward(handle, outGrads.length, ndArrayPtrs))
checkCall(_LIB.mxExecutorBackward(handle, ndArrayPtrs))
}

def backward(outGrad: NDArray): Unit = {
Expand Down
8 changes: 4 additions & 4 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ object KVStore {
* @return The created KVStore
*/
def create(name: String = "local"): KVStore = {
val handle = new KVStoreHandle
val handle = new KVStoreHandleRef
checkCall(_LIB.mxKVStoreCreate(name, handle))
new KVStore(handle)
new KVStore(handle.value)
}
}

Expand Down Expand Up @@ -156,7 +156,7 @@ class KVStore(private val handle: KVStoreHandle) {
checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker))
if ("dist" == `type` && isWorker.value != 0) {
val optSerialized = Serializer.getSerializer.serialize(optimizer)
_sendCommandToServers(0, Serializer.encodeBase64String(optSerialized))
sendCommandToServers(0, Serializer.encodeBase64String(optSerialized))
} else {
setUpdater(Optimizer.getUpdater(optimizer))
}
Expand Down Expand Up @@ -198,7 +198,7 @@ class KVStore(private val handle: KVStoreHandle) {
* @param head the head of the command
* @param body the body of the command
*/
private def _sendCommandToServers(head: Int, body: String): Unit = {
private def sendCommandToServers(head: Int, body: String): Unit = {
checkCall(_LIB.mxKVStoreSendCommmandToServers(handle, head, body))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class LibInfo {
@native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
source: Array[MXFloat],
size: Int): Int
@native def mxKVStoreCreate(name: String, handle: KVStoreHandle): Int
@native def mxKVStoreCreate(name: String, handle: KVStoreHandleRef): Int
@native def mxKVStoreInit(handle: KVStoreHandle,
len: MXUint,
keys: Array[Int],
Expand Down Expand Up @@ -104,7 +104,6 @@ class LibInfo {
@native def mxExecutorFree(handle: ExecutorHandle): Int
@native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int
@native def mxExecutorBackward(handle: ExecutorHandle,
gradsSize: Int,
grads: Array[NDArrayHandle]): Int
@native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int
@native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int
Expand Down
85 changes: 25 additions & 60 deletions scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree
// Thus we have to wrap the function in a java object, and run env->CallVoidMethod(obj) once updater is invoked,
// which implies the function registered to KVStore must be stateful.
// This is why we re-implement MXKVStoreSetUpdater as follows.
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv *env, jobject obj,
jobject kvStoreHandle,
jobject updaterFuncObj,
jobject updaterHandle) {
// get kv store ptr
jlong kvStorePtr = getLongField(env, kvStoreHandle);

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater
(JNIEnv *env, jobject obj, jlong kvStorePtr, jobject updaterFuncObj, jobject updaterHandle) {
jobject updaterFuncObjGlb = env->NewGlobalRef(updaterFuncObj);
jobject updaterHandleGlb = env->NewGlobalRef(updaterHandle);
std::function<void(int, const mxnet::NDArray&, mxnet::NDArray*)> updt
Expand Down Expand Up @@ -237,9 +232,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSetUpdater(JNIEnv *en
return 0;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreCreate(JNIEnv *env, jobject obj,
jstring name,
jobject kvStoreHandle) {
JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreCreate
(JNIEnv *env, jobject obj, jstring name, jobject kvStoreHandle) {
jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J");

Expand All @@ -252,14 +246,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreCreate(JNIEnv *env, j
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit(JNIEnv *env, jobject obj,
jobject kvStoreHandle,
jint len,
jintArray keys,
jlongArray values) {
// get kv store ptr
jlong kvStorePtr = getLongField(env, kvStoreHandle);

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit
(JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys, jlongArray values) {
jint *keyArray = env->GetIntArrayElements(keys, NULL);
jlong *valueArray = env->GetLongArrayElements(values, NULL);
int ret = MXKVStoreInit((KVStoreHandle) kvStorePtr,
Expand All @@ -271,15 +259,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreInit(JNIEnv *env, job
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush(JNIEnv *env, jobject obj,
jobject kvStoreHandle,
jint len,
jintArray keys,
jlongArray values,
jint priority) {
// get kv store ptr
jlong kvStorePtr = getLongField(env, kvStoreHandle);

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush
(JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
jlongArray values, jint priority) {
jint *keyArray = env->GetIntArrayElements(keys, NULL);
jlong *valueArray = env->GetLongArrayElements(values, NULL);
int ret = MXKVStorePush((KVStoreHandle)kvStorePtr,
Expand All @@ -292,15 +274,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePush(JNIEnv *env, job
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, jobject obj,
jobject kvStoreHandle,
jint len,
jintArray keys,
jlongArray outs,
jint priority) {
// get kv store ptr
jlong kvStorePtr = getLongField(env, kvStoreHandle);

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull
(JNIEnv *env, jobject obj, jlong kvStorePtr, jint len, jintArray keys,
jlongArray outs, jint priority) {
jint *keyArray = env->GetIntArrayElements(keys, NULL);
jlong *outArray = env->GetLongArrayElements(outs, NULL);
int ret = MXKVStorePull((KVStoreHandle)kvStorePtr,
Expand All @@ -314,8 +290,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStorePull(JNIEnv *env, job
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetType
(JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject kvType) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
(JNIEnv *env, jobject obj, jlong kvStorePtr, jobject kvType) {
const char *type;
int ret = MXKVStoreGetType((KVStoreHandle)kvStorePtr, &type);
jclass refStringClass = env->FindClass("ml/dmlc/mxnet/Base$RefString");
Expand All @@ -325,42 +300,36 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetType
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSendCommmandToServers
(JNIEnv *env, jobject obj, jobject kvStoreHandle, jint head, jstring body) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
(JNIEnv *env, jobject obj, jlong kvStorePtr, jint head, jstring body) {
const char *bodyCStr = env->GetStringUTFChars(body, 0);
int ret = MXKVStoreSendCommmandToServers((KVStoreHandle)kvStorePtr, head, bodyCStr);
env->ReleaseStringUTFChars(body, bodyCStr);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreBarrier
(JNIEnv *env, jobject obj, jobject kvStoreHandle) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
(JNIEnv *env, jobject obj, jlong kvStorePtr) {
return MXKVStoreBarrier((KVStoreHandle)kvStorePtr);
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetGroupSize
(JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject sizeRef) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
(JNIEnv *env, jobject obj, jlong kvStorePtr, jobject sizeRef) {
int size;
int ret = MXKVStoreGetGroupSize((KVStoreHandle)kvStorePtr, &size);
setIntField(env, sizeRef, size);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetRank
(JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject rankRef) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
(JNIEnv *env, jobject obj, jlong kvStorePtr, jobject rankRef) {
int rank;
int ret = MXKVStoreGetRank((KVStoreHandle)kvStorePtr, &rank);
setIntField(env, rankRef, rank);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorOutputs
(JNIEnv *env, jobject obj, jobject executorHandle, jobject outputs) {

jlong executorPtr = getLongField(env, executorHandle);
(JNIEnv *env, jobject obj, jlong executorPtr, jobject outputs) {
mx_uint outSize;
NDArrayHandle *out;
int ret = MXExecutorOutputs((ExecutorHandle)executorPtr, &outSize, &out);
Expand All @@ -381,20 +350,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorOutputs
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorFree
(JNIEnv * env, jobject obj, jobject handle) {
jlong ptr = getLongField(env, handle);
(JNIEnv * env, jobject obj, jlong ptr) {
return MXExecutorFree((ExecutorHandle) ptr);
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorForward
(JNIEnv * env, jobject obj, jobject handle, jint isTrain) {
jlong ptr = getLongField(env, handle);
(JNIEnv * env, jobject obj, jlong ptr, jint isTrain) {
return MXExecutorForward((ExecutorHandle)ptr, (int)isTrain);
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBackward
(JNIEnv * env, jobject obj, jobject handle, jint gradsSize, jlongArray grads) {
jlong executorPtr = getLongField(env, handle);
(JNIEnv * env, jobject obj, jlong executorPtr, jlongArray grads) {
int gradsSize = env->GetArrayLength(grads);
jlong *gradArr = env->GetLongArrayElements(grads, NULL);
int ret = MXExecutorBackward((ExecutorHandle)executorPtr,
(mx_uint)gradsSize,
Expand All @@ -404,17 +371,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBackward
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorPrint
(JNIEnv * env, jobject obj, jobject handle, jobject debugStr) {
jlong ptr = getLongField(env, handle);
(JNIEnv * env, jobject obj, jlong ptr, jobject debugStr) {
const char *retDebugStr;
int ret = MXExecutorPrint((ExecutorHandle)handle, &retDebugStr);
int ret = MXExecutorPrint((ExecutorHandle)ptr, &retDebugStr);
setStringField(env, debugStr, retDebugStr);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorSetMonitorCallback
(JNIEnv *env, jobject obj, jobject handle, jobject callbackFuncObj) {
jlong executorPtr = getLongField(env, handle);
(JNIEnv *env, jobject obj, jlong executorPtr, jobject callbackFuncObj) {
jobject callbackFuncObjGlb = env->NewGlobalRef(callbackFuncObj);
std::function<void(const char *, NDArrayHandle)> callback
= [env, callbackFuncObjGlb](const char *name, NDArrayHandle array) {
Expand Down

0 comments on commit b63cf23

Please sign in to comment.