Skip to content

Commit

Permalink
Merge pull request apache#6 from javelinjs/scala-package-cc
Browse files Browse the repository at this point in the history
KVStore all functions finished
  • Loading branch information
yanqingmen committed Dec 20, 2015
2 parents ea1414f + 96cd865 commit e78cc31
Show file tree
Hide file tree
Showing 9 changed files with 262 additions and 31 deletions.
4 changes: 4 additions & 0 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
</dependency>
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
Expand Down
73 changes: 73 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/KVStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,52 @@ class KVStore(private val handle: KVStoreHandle) {
pull(Array(key), Array(out), priority)
}

/**
 * Get the type of this kvstore
 * @return the type string reported by the native library for this handle
 */
def `type`: String = {
  val typeRef = new RefString
  val status = _LIB.mxKVStoreGetType(handle, typeRef)
  checkCall(status)
  typeRef.value
}

/**
 * Query the size of the worker group this store belongs to.
 * @return The number of worker nodes
 */
def numWorkers: Int = {
  val groupSize = new RefInt
  val status = _LIB.mxKVStoreGetGroupSize(handle, groupSize)
  checkCall(status)
  groupSize.value
}

/**
 * Query the rank of this worker node.
 * @return The rank of this node, which is in [0, numWorkers)
 */
def rank: Int = {
  // Named differently from the method itself to avoid shadowing `rank`.
  val rankRef = new RefInt
  val status = _LIB.mxKVStoreGetRank(handle, rankRef)
  checkCall(status)
  rankRef.value
}

/**
 * Register an optimizer to the store
 * If there are multiple machines, this process (should be a worker node)
 * will pack this optimizer and send it to all servers. It returns after
 * this action is done.
 *
 * On a non-distributed store (or on a non-worker node) the optimizer is
 * instead wrapped into a local updater and installed via setUpdater.
 *
 * @param optimizer the optimizer
 */
def setOptimizer(optimizer: Optimizer): Unit = {
  val isWorker = new RefInt
  checkCall(_LIB.mxKVStoreIsWorkerNode(isWorker))
  // Distributed store types carry a suffix (e.g. "dist_sync", "dist_async"),
  // so an exact comparison with "dist" would never match; test for the
  // substring instead, as the reference Python implementation does.
  if (`type`.contains("dist") && isWorker.value != 0) {
    val optSerialized = Serializer.getSerializer.serialize(optimizer)
    // Command head 0 carries the Base64-encoded serialized optimizer.
    _sendCommandToServers(0, Serializer.encodeBase64String(optSerialized))
  } else {
    setUpdater(Optimizer.getUpdater(optimizer))
  }
}

/**
* Set a push updater into the store.
*
Expand All @@ -110,4 +156,31 @@ class KVStore(private val handle: KVStoreHandle) {
this.updaterFunc = updater
checkCall(_LIB.mxKVStoreSetUpdater(handle, updaterFunc, null))
}

/**
 * Global barrier among all worker nodes
 *
 * For example, assume there are n machines, we want to let machine 0 first
 * init the values, and then pull the inited value to all machines. Before
 * pulling, we can place a barrier to guarantee that the initialization is
 * finished.
 */
def barrier(): Unit = {
  // Explicit `: Unit =` replaces the deprecated procedure syntax and matches
  // the other side-effecting methods of this class.
  checkCall(_LIB.mxKVStoreBarrier(handle))
}

/**
 * Send a command to all server nodes
 *
 * Each server node reacts to the command by running
 * KVStoreServer.controller. The call returns only after the command has
 * been executed in all server nodes.
 *
 * @param head the head of the command
 * @param body the body of the command
 */
private def _sendCommandToServers(head: Int, body: String): Unit = {
  // NOTE: the "Commmand" spelling is intentional here; it has to match the
  // native method name as declared in LibInfo.
  val status = _LIB.mxKVStoreSendCommmandToServers(handle, head, body)
  checkCall(status)
}
}
7 changes: 7 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,11 @@ class LibInfo {
// Native KVStore bindings. Each call returns an Int that callers in KVStore
// pass to checkCall; results are delivered through the Ref* out-parameters.

// Installs `updaterFunc` as the updater of the store behind `handle`.
// KVStore.setUpdater passes null for `updaterHandle`.
@native def mxKVStoreSetUpdater(handle: KVStoreHandle,
updaterFunc: MXKVStoreUpdater,
updaterHandle: AnyRef): Int
// Writes into `isWorker`; KVStore.setOptimizer treats a non-zero value as
// "this process is a worker node".
@native def mxKVStoreIsWorkerNode(isWorker: RefInt): Int
// Writes the store's type name into `kvType` (read back by KVStore.`type`).
@native def mxKVStoreGetType(handle: KVStoreHandle, kvType: RefString): Int
// Sends a (head, body) command to the server nodes.
// NOTE(review): "Commmand" is spelled with three m's; the name presumably has
// to match the native symbol, so renaming needs a coordinated change - confirm.
@native def mxKVStoreSendCommmandToServers(handle: KVStoreHandle,
head: Int, body: String): Int
// Barrier call used by KVStore.barrier().
@native def mxKVStoreBarrier(handle: KVStoreHandle): Int
// Writes the worker group size into `size` (read back by KVStore.numWorkers).
@native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int
// Writes this node's rank into the out-parameter (read back by KVStore.rank).
// NOTE(review): the parameter is named `size` although it receives the rank.
@native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int
}
27 changes: 25 additions & 2 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Optimizer.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
package ml.dmlc.mxnet

class Optimizer {
object Optimizer {
  /**
   * Adapt an optimizer to the MXKVStoreUpdater callback interface.
   *
   * The returned updater keeps one optimizer state object per key index,
   * created through optimizer.createState the first time that index is seen,
   * and then delegates every call to optimizer.update.
   *
   * @param optimizer the optimizer to wrap
   * @return an updater backed by the given optimizer
   */
  def getUpdater(optimizer: Optimizer): MXKVStoreUpdater = {
    new MXKVStoreUpdater {
      // Per-index optimizer state, filled lazily on first update.
      private val stateByIndex = scala.collection.mutable.HashMap.empty[Int, AnyRef]

      override def update(index: Int, grad: NDArray, weight: NDArray, handle: AnyRef): Unit = {
        val state = stateByIndex.get(index) match {
          case Some(existing) => existing
          case None =>
            val created = optimizer.createState(index, weight)
            stateByIndex(index) = created
            created
        }
        // Note the argument order: the optimizer takes weight before grad.
        optimizer.update(index, weight, grad, state)
      }
    }
  }
}

/**
 * Base class of optimizers.
 * It extends Serializable; KVStore.setOptimizer serializes the instance
 * before sending it to the server nodes.
 */
abstract class Optimizer extends Serializable {
/**
 * Update the parameters.
 * The default body is `???`, so calling it on a subclass that does not
 * override it throws scala.NotImplementedError.
 * @param index A unique integer key used to index the parameters
 * @param weight weight ndarray
 * @param grad grad ndarray
 * @param state NDArray or other objects returned by createState
 * The auxiliary state used in optimization.
 */
def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = ???

/**
 * Create additional optimizer state such as momentum.
 * @param index the integer key of the parameter the state belongs to
 * @param weight weight ndarray of that parameter
 * @return the state object later passed to update for the same index
 */
def createState(index: Int, weight: NDArray): AnyRef
}

trait MXKVStoreUpdater {
Expand All @@ -13,5 +36,5 @@ trait MXKVStoreUpdater {
* @param local the value stored on local on this key
* @param handle The additional handle to the updater
*/
def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef): Unit
def update(key: Int, recv: NDArray, local: NDArray, handle: AnyRef = null): Unit
}
59 changes: 59 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package ml.dmlc.mxnet

import java.io._
import java.nio.ByteBuffer
import java.nio.charset.Charset

import org.apache.commons.codec.binary.Base64

import scala.reflect.ClassTag

/**
 * Serialize & deserialize Java/Scala [[Serializable]] objects
 * @author Yizhi Liu
 */
abstract class Serializer {
/**
 * Serialize an object into a byte buffer.
 * @param t the object to serialize
 * @return a buffer holding the serialized form of `t`
 */
def serialize[T: ClassTag](t: T): ByteBuffer
/**
 * Rebuild an object from its serialized form.
 * @param bytes a buffer holding the serialized object
 * @return the deserialized object, typed as `T`
 */
def deserialize[T: ClassTag](bytes: ByteBuffer): T
}

/**
 * Factory and Base64 helpers for [[Serializer]].
 */
object Serializer {
  /** Charset used to turn Base64 bytes into strings and back. */
  val UTF8: Charset = Charset.forName("UTF-8")

  /** @return the default serializer (currently a JavaSerializer) */
  def getSerializer: Serializer = getSerializer(None)

  /**
   * @param serializer a serializer to use, or null to get the default one
   * @return `serializer` itself when non-null, otherwise a new JavaSerializer
   */
  def getSerializer(serializer: Serializer): Serializer = {
    // TODO: dynamically get from mxnet env to support other serializers like Kryo
    // Option(...) maps null to None, so both overloads share one code path.
    getSerializer(Option(serializer))
  }

  /**
   * @param serializer an optional serializer to use
   * @return the given serializer, or a new JavaSerializer when empty
   */
  def getSerializer(serializer: Option[Serializer]): Serializer = {
    // TODO: dynamically get from mxnet env to support other serializers like Kryo
    serializer.getOrElse(new JavaSerializer)
  }

  /**
   * Encode the readable content of a buffer as a Base64 string.
   * @param bytes the buffer to encode; its position and limit are left untouched
   * @return the Base64 text of the bytes between position and limit
   */
  def encodeBase64String(bytes: ByteBuffer): String = {
    // `bytes.array` would expose the whole backing array (ignoring position,
    // limit and arrayOffset) and throws for read-only or direct buffers, so
    // copy exactly the readable window through a duplicate view instead.
    val view = bytes.duplicate()
    val payload = new Array[Byte](view.remaining())
    view.get(payload)
    new String(Base64.encodeBase64(payload), UTF8)
  }

  /**
   * Decode a Base64 string produced by encodeBase64String.
   * @param str the Base64 text
   * @return a buffer wrapping the decoded bytes
   */
  def decodeBase64String(str: String): ByteBuffer = {
    ByteBuffer.wrap(Base64.decodeBase64(str.getBytes(UTF8)))
  }
}

/**
 * [[Serializer]] implementation based on Java object streams.
 */
class JavaSerializer extends Serializer {
  /**
   * Serialize an object with an ObjectOutputStream.
   * @param t the object to serialize
   * @return a buffer wrapping the serialized bytes
   */
  override def serialize[T: ClassTag](t: T): ByteBuffer = {
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    try {
      out.writeObject(t)
    } finally {
      // Closing also flushes, and now happens even when writeObject throws.
      out.close()
    }
    ByteBuffer.wrap(bos.toByteArray)
  }

  /**
   * Deserialize an object with an ObjectInputStream.
   * @param bytes buffer whose readable content (position to limit) holds the
   *              serialized object; its position and limit are left untouched
   * @return the deserialized object, cast to `T`
   */
  override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
    // Copy only the readable window: `bytes.array()` ignores position, limit
    // and arrayOffset, and is unavailable for read-only or direct buffers.
    val view = bytes.duplicate()
    val payload = new Array[Byte](view.remaining())
    view.get(payload)
    val in = new ObjectInputStream(new ByteArrayInputStream(payload))
    try {
      in.readObject().asInstanceOf[T]
    } finally {
      in.close()
    }
  }
}
11 changes: 11 additions & 0 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,15 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
kv.pull(3, ndArray)
assert(ndArray.toArray === Array(6f, 6f))
}

// A store created as "local" must report "local" as its type.
test("get type") {
  val store = KVStore.create("local")
  val reportedType = store.`type`
  assert(reportedType === "local")
}

// A local store is expected to report one worker, with this node at rank 0.
test("get numWorkers and rank") {
  val store = KVStore.create("local")
  val workers = store.numWorkers
  assert(workers === 1)
  val nodeRank = store.rank
  assert(nodeRank === 0)
}
}
24 changes: 24 additions & 0 deletions scala-package/native/src/main/native/jni_helper_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include <jni.h>

#ifndef MXNET_SCALA_JNI_HELPER_FUNC_H
#define MXNET_SCALA_JNI_HELPER_FUNC_H

// Reads the long field named "value" from `obj`, looked up on class
// ml.dmlc.mxnet.Base$RefLong.
// Returns 0 when the class or field lookup fails; in that case the failed
// JNI call leaves a Java exception pending for the caller.
// Declared inline because this definition lives in a header and would
// otherwise be defined once per including translation unit.
inline jlong getLongField(JNIEnv *env, jobject obj) {
  jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong");
  if (!refClass) {
    return 0;
  }
  jfieldID refFid = env->GetFieldID(refClass, "value", "J");
  if (!refFid) {
    env->DeleteLocalRef(refClass);
    return 0;
  }
  jlong value = env->GetLongField(obj, refFid);
  // Release the class reference so repeated calls do not pile up local refs.
  env->DeleteLocalRef(refClass);
  return value;
}

// Reads the int field named "value" from `obj`, looked up on class
// ml.dmlc.mxnet.Base$RefInt.
// Returns 0 when the class or field lookup fails; in that case the failed
// JNI call leaves a Java exception pending for the caller.
// Declared inline because this definition lives in a header and would
// otherwise be defined once per including translation unit.
inline jint getIntField(JNIEnv *env, jobject obj) {
  jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt");
  if (!refClass) {
    return 0;
  }
  jfieldID refFid = env->GetFieldID(refClass, "value", "I");
  if (!refFid) {
    env->DeleteLocalRef(refClass);
    return 0;
  }
  jint value = env->GetIntField(obj, refFid);
  // Release the class reference so repeated calls do not pile up local refs.
  env->DeleteLocalRef(refClass);
  return value;
}

// Stores `value` into the int field named "value" of `obj`, looked up on
// class ml.dmlc.mxnet.Base$RefInt.
// Does nothing when the class or field lookup fails; in that case the failed
// JNI call leaves a Java exception pending for the caller.
// Declared inline because this definition lives in a header and would
// otherwise be defined once per including translation unit.
inline void setIntField(JNIEnv *env, jobject obj, jint value) {
  jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt");
  if (!refClass) {
    return;
  }
  jfieldID refFid = env->GetFieldID(refClass, "value", "I");
  if (refFid) {
    env->SetIntField(obj, refFid, value);
  }
  // Release the class reference so repeated calls do not pile up local refs.
  env->DeleteLocalRef(refClass);
}

#endif
Loading

0 comments on commit e78cc31

Please sign in to comment.