forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
KVStore setOptimizer, with JavaSerializer
- Loading branch information
Showing
6 changed files
with
176 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
58 changes: 58 additions & 0 deletions
58
scala-package/core/src/main/scala/ml/dmlc/mxnet/Serializer.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
package ml.dmlc.mxnet | ||
|
||
import java.io._ | ||
import java.nio.ByteBuffer | ||
import java.nio.charset.Charset | ||
import java.util.Base64 | ||
|
||
import scala.reflect.ClassTag | ||
|
||
/** | ||
* Serialize & deserialize Java/Scala [[Serializable]] objects | ||
* @author Yizhi Liu | ||
*/ | ||
abstract class Serializer { | ||
def serialize[T: ClassTag](t: T): ByteBuffer | ||
def deserialize[T: ClassTag](bytes: ByteBuffer): T | ||
} | ||
|
||
object Serializer { | ||
val UTF8 = Charset.forName("UTF-8") | ||
|
||
def getSerializer: Serializer = getSerializer(None) | ||
|
||
def getSerializer(serializer: Serializer): Serializer = { | ||
// TODO: dynamically get from mxnet env to support other serializers like Kyro | ||
if (serializer == null) new JavaSerializer else serializer | ||
} | ||
|
||
def getSerializer(serializer: Option[Serializer]): Serializer = { | ||
// TODO: dynamically get from mxnet env to support other serializers like Kyro | ||
serializer.getOrElse(new JavaSerializer) | ||
} | ||
|
||
def encodeBase64String(bytes: ByteBuffer): String = { | ||
new String(Base64.getEncoder.encode(bytes).array, UTF8) | ||
} | ||
|
||
def decodeBase64String(str: String): ByteBuffer = { | ||
ByteBuffer.wrap(Base64.getDecoder.decode(str.getBytes(UTF8))) | ||
} | ||
} | ||
|
||
class JavaSerializer extends Serializer { | ||
override def serialize[T: ClassTag](t: T): ByteBuffer = { | ||
val bos = new ByteArrayOutputStream() | ||
val out = new ObjectOutputStream(bos) | ||
out.writeObject(t) | ||
out.close() | ||
ByteBuffer.wrap(bos.toByteArray) | ||
} | ||
|
||
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { | ||
val byteArray = bytes.array() | ||
val bis = new ByteArrayInputStream(byteArray) | ||
val in = new ObjectInputStream(bis) | ||
in.readObject().asInstanceOf[T] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#include <jni.h> | ||
|
||
#ifndef MXNET_SCALA_JNI_HELPER_FUNC_H | ||
#define MXNET_SCALA_JNI_HELPER_FUNC_H | ||
|
||
jlong getLongField(JNIEnv *env, jobject obj) { | ||
jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); | ||
jfieldID refFid = env->GetFieldID(refClass, "value", "J"); | ||
return env->GetLongField(obj, refFid); | ||
} | ||
|
||
jint getIntField(JNIEnv *env, jobject obj) { | ||
jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefInt"); | ||
jfieldID refFid = env->GetFieldID(refClass, "value", "I"); | ||
return env->GetIntField(obj, refFid); | ||
} | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters