diff --git a/.gitignore b/.gitignore index cd24d897adf0..319f601ba95e 100644 --- a/.gitignore +++ b/.gitignore @@ -97,4 +97,4 @@ scala-package/*/*/target/ *.iml *.classpath *.project -*.settings +*.settings \ No newline at end of file diff --git a/scala-package/core/scripts/get_mnist_data.sh b/scala-package/core/scripts/get_mnist_data.sh new file mode 100755 index 000000000000..e080144f6663 --- /dev/null +++ b/scala-package/core/scripts/get_mnist_data.sh @@ -0,0 +1,11 @@ +data_path="./data" +if [ ! -d "$data_path" ]; then + mkdir -p "$data_path" +fi + +mnist_data_path="./data/mnist.zip" +if [ ! -f "$mnist_data_path" ]; then + wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip -P $data_path + cd $data_path + unzip -u mnist.zip +fi diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index faa1cec9bc27..0d5c848afd79 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -15,9 +15,12 @@ object Base { type MXFloatRef = RefFloat type NDArrayHandle = RefLong type FunctionHandle = RefLong + type DataIterHandle = RefLong + type DataIterCreator = RefLong type KVStoreHandle = RefLong type ExecutorHandle = RefLong + System.loadLibrary("mxnet-scala") val _LIB = new LibInfo diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala new file mode 100644 index 000000000000..11dcfadbbcff --- /dev/null +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -0,0 +1,237 @@ +package ml.dmlc.mxnet + +import ml.dmlc.mxnet.Base._ +import org.slf4j.LoggerFactory + +import scala.collection.mutable.ListBuffer + +object IO { + type IterCreateFunc = (Map[String, String]) => DataIter + + private val logger = LoggerFactory.getLogger(classOf[DataIter]) + private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule() + + /** + * create iterator via iterName and params + * @param iterName name of iterator; "MNISTIter" or "ImageRecordIter" + * @param params paramters for create iterator + * @return + */ + def createIterator(iterName: String, params: Map[String, String]): DataIter = { + return iterCreateFuncs(iterName)(params) + } + + /** + * initi all IO creator Functions + * @return + */ + private def _initIOModule(): Map[String, IterCreateFunc] = { + val IterCreators = new ListBuffer[DataIterCreator] + checkCall(_LIB.mxListDataIters(IterCreators)) + IterCreators.map(_makeIOIterator).toMap + } + + private def _makeIOIterator(handle: DataIterCreator): (String, IterCreateFunc) = { + val name = new RefString + val desc = new RefString + val argNames = new ListBuffer[String] + val argTypes = new ListBuffer[String] + val argDescs = new ListBuffer[String] + checkCall(_LIB.mxDataIterGetIterInfo(handle, name, desc, argNames, argTypes, argDescs)) + val paramStr = Base.ctypes2docstring(argNames, argTypes, argDescs) + val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n" + logger.debug(docStr) + return (name.value, creator(handle)) + } + + /** + * + * @param handle + * @param params + * @return + */ + private def creator(handle: DataIterCreator)( + params: Map[String, String]): DataIter = { + val out = new DataIterHandle + val keys = params.keys.toArray + val vals = params.values.toArray + checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out)) + return new MXDataIter(out) + } +} + + +/** + * class batch of data + * @param data + * @param label + * @param index + * @param pad + */ +case class DataBatch(val data: NDArray, + val label: NDArray, + val index: List[Long], + val pad: Int) + +/** + *DataIter object in mxnet. + */ +abstract class DataIter (val batchSize: Int = 0) { + /** + * reset the iterator + */ + def reset(): Unit + /** + * Iterate to next batch + * @return whether the move is successful + */ + def iterNext(): Boolean + + /** + * get next data batch from iterator + * @return + */ + def next(): DataBatch = { + return new DataBatch(getData(), getLabel(), getIndex(), getPad()) + } + + /** + * get data of current batch + * @return the data of current batch + */ + def getData(): NDArray + + /** + * Get label of current batch + * @return the label of current batch + */ + def getLabel(): NDArray + + /** + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ + def getPad(): Int + + /** + * the index of current batch + * @return + */ + def getIndex(): List[Long] + +} + +/** + * DataIter built in MXNet. + * @param handle the handle to the underlying C++ Data Iterator + */ +class MXDataIter(val handle: DataIterHandle) extends DataIter { + private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) + + override def finalize() = { + checkCall(_LIB.mxDataIterFree(handle)) + } + + /** + * reset the iterator + */ + override def reset(): Unit = { + checkCall(_LIB.mxDataIterBeforeFirst(handle)) + } + + /** + * Iterate to next batch + * @return whether the move is successful + */ + override def iterNext(): Boolean = { + val next = new RefInt + checkCall(_LIB.mxDataIterNext(handle, next)) + return next.value > 0 + } + + /** + * get data of current batch + * @return the data of current batch + */ + override def getData(): NDArray = { + val out = new NDArrayHandle + checkCall(_LIB.mxDataIterGetData(handle, out)) + return new NDArray(out, writable = false) + } + + /** + * Get label of current batch + * @return the label of current batch + */ + override def getLabel(): NDArray = { + val out = new NDArrayHandle + checkCall(_LIB.mxDataIterGetLabel(handle, out)) + return new NDArray(out, writable = false) + } + + /** + * the index of current batch + * @return + */ + override def getIndex(): List[Long] = { + val outIndex = new ListBuffer[Long] + val outSize = new RefLong + checkCall(_LIB.mxDataIterGetIndex(handle, outIndex, outSize)) + return outIndex.toList + } + + /** + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ + override def getPad(): MXUint = { + val out = new MXUintRef + checkCall(_LIB.mxDataIterGetPadNum(handle, out)) + return out.value + } +} + +/** + * To do + */ +class ArrayDataIter() extends DataIter { + /** + * reset the iterator + */ + override def reset(): Unit = ??? + + /** + * get data of current batch + * @return the data of current batch + */ + override def getData(): NDArray = ??? + + /** + * Get label of current batch + * @return the label of current batch + */ + override def getLabel(): NDArray = ??? + + /** + * the index of current batch + * @return + */ + override def getIndex(): List[Long] = ??? + + /** + * Iterate to next batch + * @return whether the move is successful + */ + override def iterNext(): Boolean = ??? + + /** + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ + override def getPad(): MXUint = ??? +} + + diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 715d39fc67a8..cdc69d857c8f 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -84,6 +84,32 @@ class LibInfo { @native def mxKVStoreBarrier(handle: KVStoreHandle): Int @native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int @native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int + + //DataIter Funcs + @native def mxListDataIters(handles: ListBuffer[DataIterCreator]): Int + @native def mxDataIterCreateIter(handle: DataIterCreator, + keys: Array[String], + vals: Array[String], + out: DataIterHandle): Int + @native def mxDataIterGetIterInfo(creator: DataIterCreator, + name: RefString, + description: RefString, + argNames: ListBuffer[String], + argTypeInfos: ListBuffer[String], + argDescriptions: ListBuffer[String]): Int + @native def mxDataIterFree(handle: DataIterHandle): Int + @native def mxDataIterBeforeFirst(handle: DataIterHandle): Int + @native def mxDataIterNext(handle: DataIterHandle, out: RefInt): Int + @native def mxDataIterGetLabel(handle: DataIterHandle, + out: NDArrayHandle): Int + @native def mxDataIterGetData(handle: DataIterHandle, + out: NDArrayHandle): Int + @native def mxDataIterGetIndex(handle: DataIterHandle, + outIndex: ListBuffer[Long], + outSize: RefLong): Int + @native def mxDataIterGetPadNum(handle: DataIterHandle, + out: MXUintRef): Int + //Executors @native def mxExecutorOutputs(handle: ExecutorHandle, outputs: ArrayBuffer[NDArrayHandle]): Int @native def mxExecutorFree(handle: ExecutorHandle): Int @native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala new file mode 100644 index 000000000000..fc6940cec50b --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -0,0 +1,76 @@ +package ml.dmlc.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import scala.sys.process._ + + +class IOSuite extends FunSuite with BeforeAndAfterAll { + test("test MNISTIter") { + //get data + "./scripts/get_mnist_data.sh" ! + + val params = Map( + "image" -> "data/train-images-idx3-ubyte", + "label" -> "data/train-labels-idx1-ubyte", + "data_shape" -> "(784,)", + "batch_size" -> "100", + "shuffle" -> "1", + "flat" -> "1", + "silent" -> "0", + "seed" -> "10" + ) + + val mnistIter = IO.createIterator("MNISTIter", params) + //test_loop + mnistIter.reset() + val nBatch = 600 + var batchCount = 0 + while(mnistIter.iterNext()) { + val batch = mnistIter.next() + batchCount+=1 + } + //test loop + assert(nBatch === batchCount) + //test reset + mnistIter.reset() + mnistIter.iterNext() + val label0 = mnistIter.getLabel().toArray + mnistIter.iterNext() + mnistIter.iterNext() + mnistIter.iterNext() + mnistIter.reset() + mnistIter.iterNext() + val label1 = mnistIter.getLabel().toArray + assert(label0 === label1) + } + + + /** + * not work now + */ +// test("test ImageRecordIter") { +// //get data +// //"./scripts/get_cifar_data.sh" ! +// +// val params = Map( +// "path_imgrec" -> "data/cifar/train.rec", +// "mean_img" -> "data/cifar/cifar10_mean.bin", +// "rand_crop" -> "False", +// "and_mirror" -> "False", +// "shuffle" -> "False", +// "data_shape" -> "(3,28,28)", +// "batch_size" -> "100", +// "preprocess_threads" -> "4", +// "prefetch_buffer" -> "1" +// ) +// val img_iter = IO.createIterator("ImageRecordIter", params) +// img_iter.reset() +// while(img_iter.iterNext()) { +// val batch = img_iter.next() +// } +// } + +// test("test NDarryIter") { +// +// } +} diff --git a/scala-package/native/linux-x86_64/pom.xml b/scala-package/native/linux-x86_64/pom.xml index 398cf01219bd..a0dedfb8122b 100644 --- a/scala-package/native/linux-x86_64/pom.xml +++ b/scala-package/native/linux-x86_64/pom.xml @@ -13,6 +13,10 @@ MXNet Scala Package - Native Linux-x86_64 http://maven.apache.org + + opencv.pkg.txt + + so @@ -74,6 +78,12 @@ + + + + + + @@ -103,7 +113,7 @@ - -msse3 -funroll-loops -Wno-unused-parameter -Wno-unknown-pragmas + -msse3 -funroll-loops -Wno-unused-parameter -Wno-unknown-pragmas -fopenmp @@ -114,15 +124,19 @@ -DMSHADOW_USE_CBLAS=${use.cblas} -DMSHADOW_USE_MKL=${use.mkl} -fPIC + -shared - -fopenmp ${ldflags.blas} -Wl,--whole-archive ../../../lib/libmxnet.a - -Wl,-no-whole-archive + ../../../dmlc-core/libdmlc.a + -Wl,--no-whole-archive + -lm -lrt -fopenmp + ${ldflags.opencv} + ${ldflags.blas} diff --git a/scala-package/native/src/main/native/jni_helper_func.h b/scala-package/native/src/main/native/jni_helper_func.h index 43668bc1df9d..cce9cb0efe22 100644 --- a/scala-package/native/src/main/native/jni_helper_func.h +++ b/scala-package/native/src/main/native/jni_helper_func.h @@ -21,10 +21,15 @@ void setIntField(JNIEnv *env, jobject obj, jint value) { env->SetIntField(obj, refFid, value); } +void setLongField(JNIEnv *env, jobject obj, jlong value) { + jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jfieldID refFid = env->GetFieldID(refClass, "value", "J"); + env->SetLongField(obj, refFid, value); +} + void setStringField(JNIEnv *env, jobject obj, const char *value) { jclass refClass = env->FindClass("ml/dmlc/mxnet/Base$RefString"); jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;"); env->SetObjectField(obj, refFid, env->NewStringUTF(value)); } - #endif diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index c387f24e1350..ade79168a9a7 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -489,3 +489,180 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxNDArrayFree(JNIEnv * env, jo puts("Free ndarray called"); return 0; } + +//IO funcs +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxListDataIters + (JNIEnv * env, jobject obj, jobject creators) { + // Base.FunctionHandle.constructor + jclass chClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jmethodID chConstructor = env->GetMethodID(chClass,"","(J)V"); + + // scala.collection.mutable.ListBuffer append method + jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + + // Get function list + DataIterCreator *outArray; + mx_uint outSize; + int ret = MXListDataIters(&outSize, &outArray); + for (int i = 0; i < outSize; ++i) { + DataIterCreator chAddr = outArray[i]; + jobject chObj = env->NewObject(chClass, chConstructor, (long)chAddr); + env->CallObjectMethod(creators, listAppend, chObj); + } + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterCreateIter + (JNIEnv * env, jobject obj, jobject creator, + jobjectArray jkeys, jobjectArray jvals, jobject dataIterHandle) { + //keys and values + int paramSize = env->GetArrayLength(jkeys); + char** keys = new char*[paramSize]; + char** vals = new char*[paramSize]; + jstring jkey, jval; + //use strcpy and release char* created by JNI inplace + for(int i=0; iGetObjectArrayElement(jkeys, i); + const char* ckey = env->GetStringUTFChars(jkey, 0); + keys[i] = new char[env->GetStringLength(jkey)]; + strcpy(keys[i], ckey); + env->ReleaseStringUTFChars(jkey, ckey); + + jval = (jstring) env->GetObjectArrayElement(jvals, i); + const char* cval = env->GetStringUTFChars(jval, 0); + vals[i] = new char[env->GetStringLength(jval)]; + strcpy(vals[i], cval); + env->ReleaseStringUTFChars(jval, cval); + } + + //create iter + jlong creatorPtr = getLongField(env, creator); + DataIterHandle out; + int ret = MXDataIterCreateIter((DataIterCreator)creatorPtr, + (mx_uint) paramSize, + (const char**) keys, + (const char**) vals, + &out); + jclass hClass = env->GetObjectClass(dataIterHandle); + jfieldID ptr = env->GetFieldID(hClass, "value", "J"); + env->SetLongField(dataIterHandle, ptr, (long)out); + + //release keys and vals + for(int i=0; iFindClass("ml/dmlc/mxnet/Base$RefString"); + jfieldID valueStr = env->GetFieldID(refStringClass, "value", "Ljava/lang/String;"); + //set params + env->SetObjectField(jname, valueStr, env->NewStringUTF(name)); + env->SetObjectField(jdesc, valueStr, env->NewStringUTF(description)); + jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + for(int i=0; iCallObjectMethod(jargNames, listAppend, env->NewStringUTF(argNames[i])); + env->CallObjectMethod(jargTypeInfos, listAppend, env->NewStringUTF(argTypeInfos[i])); + env->CallObjectMethod(jargDescs, listAppend, env->NewStringUTF(argDescs[i])); + } + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterFree + (JNIEnv *env, jobject obj, jobject handle) { + jlong handlePtr = getLongField(env, handle); + int ret = MXDataIterFree((DataIterHandle) handlePtr); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterBeforeFirst + (JNIEnv *env, jobject obj, jobject handle) { + jlong handlePtr = getLongField(env, handle); + int ret = MXDataIterBeforeFirst((DataIterHandle) handlePtr); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterNext + (JNIEnv *env, jobject obj, jobject handle, jobject out) { + jlong handlePtr = getLongField(env, handle); + int cout; + int ret = MXDataIterNext((DataIterHandle)handlePtr, &cout); + setIntField(env, out, cout); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetLabel + (JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) { + jlong handlePtr = getLongField(env, handle); + NDArrayHandle out; + int ret = MXDataIterGetLabel((DataIterHandle)handlePtr, &out); + jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); + env->SetLongField(ndArrayHandle, refLongFid, (jlong)out); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetData + (JNIEnv *env, jobject obj, jobject handle, jobject ndArrayHandle) { + jlong handlePtr = getLongField(env, handle); + NDArrayHandle out; + int ret = MXDataIterGetData((DataIterHandle)handlePtr, &out); + jclass refLongClass = env->FindClass("ml/dmlc/mxnet/Base$RefLong"); + jfieldID refLongFid = env->GetFieldID(refLongClass, "value", "J"); + env->SetLongField(ndArrayHandle, refLongFid, (jlong)out); + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex + (JNIEnv *env, jobject obj, jobject handle, jobject outIndex, jobject outSize) { + jlong handlePtr = getLongField(env, handle); + uint64_t* coutIndex; + uint64_t coutSize; + int ret = MXDataIterGetIndex((DataIterHandle)handlePtr, &coutIndex, &coutSize); + //set field + setLongField(env, outSize, (long)coutSize); + // scala.collection.mutable.ListBuffer append method + jclass listClass = env->FindClass("scala/collection/mutable/ListBuffer"); + jmethodID listAppend = env->GetMethodID(listClass, + "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;"); + + for(int i=0; iCallObjectMethod(outIndex, listAppend, (jlong)coutIndex[i]); + } + return ret; +} + +JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetPadNum + (JNIEnv *env, jobject obj, jobject handle, jobject pad) { + jlong handlePtr = getLongField(env, handle); + int cpad; + int ret = MXDataIterGetPadNum((DataIterHandle)handlePtr, &cpad); + setIntField(env, pad, cpad); + return ret; +} diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index f3919a06fa20..43ed4cc922a0 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -131,12 +131,6 @@ if [ ${TASK} == "scala_test" ]; then mvn integration-test -P osx-x86_64 --log-file scala_test_results.txt fi if [ ${TRAVIS_OS_NAME} == "linux" ]; then - # (Yizhi Liu) I'm not sure it is a proper solution, - # which is mentioned here: - # http://stackoverflow.com/questions/9558909/jni-symbol-lookup-error-in-shared-library-on-linux/13086028#13086028 - # I really don't know why we have to export LD_PRELOAD - # to make libblas loaded in travis. It just works. - export LD_PRELOAD=/usr/lib/libblas/libblas.so # use g++-4.8 for linux mvn clean package -P linux-x86_64 -D cxx=g++-4.8 -D ldflags.blas=-lblas mvn integration-test -P linux-x86_64 -D cxx=g++-4.8 -D ldflags.blas=-lblas --log-file scala_test_results.txt