Skip to content

Commit

Permalink
add KVStore rank & numWorkers
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhliu committed Dec 20, 2015
1 parent ebba851 commit 96cd865
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 6 deletions.
4 changes: 4 additions & 0 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
</dependency>
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
Expand Down
36 changes: 34 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,32 @@ class KVStore(private val handle: KVStoreHandle) {
}

// Get the type of this kvstore
def getType: String = {
def `type`: String = {
val kvType = new RefString
checkCall(_LIB.mxKVStoreGetType(handle, kvType))
kvType.value
}

/**
* Get the number of worker nodes
* @return The number of worker nodes
*/
def numWorkers: Int = {
val size = new RefInt
checkCall(_LIB.mxKVStoreGetGroupSize(handle, size))
size.value
}

/**
* Get the rank of this worker node
* @return The rank of this node, which is in [0, get_num_workers())
*/
def rank: Int = {
val rank = new RefInt
checkCall(_LIB.mxKVStoreGetRank(handle, rank))
rank.value
}

/**
* Register an optimizer to the store
* If there are multiple machines, this process (should be a worker node)
Expand All @@ -116,7 +136,7 @@ class KVStore(private val handle: KVStoreHandle) {
def setOptimizer(optimizer: Optimizer): Unit = {
val isWorker = new RefInt
checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker))
if ("dist" == getType && isWorker.value != 0) {
if ("dist" == `type` && isWorker.value != 0) {
val optSerialized = Serializer.getSerializer.serialize(optimizer)
_sendCommandToServers(0, Serializer.encodeBase64String(optSerialized))
} else {
Expand All @@ -137,6 +157,18 @@ class KVStore(private val handle: KVStoreHandle) {
checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc, null))
}

/**
* Global barrier among all worker nodes
*
* For example, assume there are n machines, we want to let machine 0 first
* init the values, and then pull the inited value to all machines. Before
* pulling, we can place a barrier to guarantee that the initialization is
* finished.
*/
def barrier() {
checkCall(_LIB.mxKVStoreBarrier(handle))
}

/**
* Send a command to all server nodes
*
Expand Down
3 changes: 3 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,7 @@ class LibInfo {
@native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int
@native def mxKVStoreSendCommmandToServers(handle: KVStoreHandle,
head: Int, body: String): Int
@native def mxKVStoreBarrier(handle: KVStoreHandle): Int
@native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int
@native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package ml.dmlc.mxnet
object Optimizer {
def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
new MXKVStoreUpdater {
val states = new scala.collection.mutable.HashMap[Int, AnyRef]
private val states = new scala.collection.mutable.HashMap[Int, AnyRef]
override def update(index: Int, grad: NDArray, weight: NDArray, handle: AnyRef): Unit = {
val state = states.getOrElseUpdate(index, optimizer.createState(index, weight))
optimizer.update(index, weight, grad, state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package ml.dmlc.mxnet
import java.io._
import java.nio.ByteBuffer
import java.nio.charset.Charset
import java.util.Base64

import org.apache.commons.codec.binary.Base64

import scala.reflect.ClassTag

Expand Down Expand Up @@ -32,11 +33,11 @@ object Serializer {
}

def encodeBase64String(bytes: ByteBuffer): String = {
new String(Base64.getEncoder.encode(bytes).array, UTF8)
new String(Base64.encodeBase64(bytes.array), UTF8)
}

def decodeBase64String(str: String): ByteBuffer = {
ByteBuffer.wrap(Base64.getDecoder.decode(str.getBytes(UTF8)))
ByteBuffer.wrap(Base64.decodeBase64(str.getBytes(UTF8)))
}
}

Expand Down
11 changes: 11 additions & 0 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,15 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
kv.pull(3, ndArray)
assert(ndArray.toArray === Array(6f, 6f))
}

test("get type") {
val kv = KVStore.create("local")
assert(kv.`type` === "local")
}

test("get numWorkers and rank") {
val kv = KVStore.create("local")
assert(kv.numWorkers === 1)
assert(kv.rank === 0)
}
}
6 changes: 6 additions & 0 deletions scala-package/native/src/main/native/jni_helper_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,10 @@ jint getIntField(JNIEnv *env, jobject obj) {
return env->GetIntField(obj, refFid);
}

void setIntField(JNIEnv *env, jobject obj, jint value) {
jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt");
jfieldID refFid = env->GetFieldID(refClass, "value", "I");
env->SetIntField(obj, refFid, value);
}

#endif
24 changes: 24 additions & 0 deletions scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,30 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreSendCommmandToServers
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreBarrier
(JNIEnv *env, jobject obj, jobject kvStoreHandle) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
return MXKVStoreBarrier((KVStoreHandle)kvStorePtr);
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetGroupSize
(JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject sizeRef) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
int size;
int ret = MXKVStoreGetGroupSize((KVStoreHandle)kvStorePtr, &size);
setIntField(env, sizeRef, size);
return ret;
}

JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxKVStoreGetRank
(JNIEnv *env, jobject obj, jobject kvStoreHandle, jobject rankRef) {
jlong kvStorePtr = getLongField(env, kvStoreHandle);
int rank;
int ret = MXKVStoreGetRank((KVStoreHandle)kvStorePtr, &rank);
setIntField(env, rankRef, rank);
return ret;
}

JNIEXPORT jstring JNICALL Java_ml_dmlc_mxnet_LibInfo_mxGetLastError(JNIEnv * env, jobject obj) {
char *tmpstr = "MXNetError";
jstring rtstr = env->NewStringUTF(tmpstr);
Expand Down
5 changes: 5 additions & 0 deletions scala-package/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
<version>1.10</version>
</dependency>
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
Expand Down

0 comments on commit 96cd865

Please sign in to comment.