Skip to content

Commit

Permalink
Merge pull request apache#24 from yanqingmen/scala
Browse files Browse the repository at this point in the history
small modification for io
  • Loading branch information
yzhliu committed Jan 28, 2016
2 parents 9979ca5 + d86cc42 commit 9a60d20
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 16 deletions.
11 changes: 11 additions & 0 deletions scala-package/core/scripts/get_cifar_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
data_path="./data"
if [ ! -d "$data_path" ]; then
mkdir -p "$data_path"
fi

cifar_data_path="./data/cifar10.zip"
if [ ! -f "$cifar_data_path" ]; then
wget http://webdocs.cs.ualberta.ca/~bx3/data/cifar10.zip -P $data_path
cd $data_path
unzip -u cifar10.zip
fi
35 changes: 28 additions & 7 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ object IO {
private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule()

def MNISTIter: IterCreateFunc = iterCreateFuncs("MNISTIter")
def ImageRecordIter: IterCreateFunc = iterCreateFuncs("ImageRecordIter")
def CSVIter: IterCreateFunc = iterCreateFuncs("CSVIter")

/**
* create iterator via iterName and params
Expand Down Expand Up @@ -66,7 +68,7 @@ object IO {
}

// Convert data into canonical form.
private def initData(data: NDArray, allowEmpty: Boolean, defaultName: String) = {
private def initData(data: List[NDArray], allowEmpty: Boolean, defaultName: String) = {
require(data != null || allowEmpty)
// TODO
}
Expand Down Expand Up @@ -145,9 +147,17 @@ abstract class DataIter(val batchSize: Int = 0) {
* @param handle the handle to the underlying C++ Data Iterator
*/
// scalastyle:off finalize
class MXDataIter(val handle: DataIterHandle) extends DataIter {
class MXDataIter(val handle: DataIterHandle,
val dataName: String = "data",
val labelName: String = "label") extends DataIter {
private val logger = LoggerFactory.getLogger(classOf[MXDataIter])

// get first batch
reset()
iterNext()
val firstBatch = next()
reset()

override def finalize(): Unit = {
checkCall(_LIB.mxDataIterFree(handle))
}
Expand Down Expand Up @@ -212,17 +222,28 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter {
}

// The name and shape of data provided by this iterator
override def provideData: Map[String, Shape] = ???
override def provideData: Map[String, Shape] = {
Map(dataName -> firstBatch.data.head.shape)
}

// The name and shape of label provided by this iterator
override def provideLabel: Map[String, Shape] = ???
override def provideLabel: Map[String, Shape] = {
Map(labelName -> firstBatch.label.head.shape)
}
}
// scalastyle:on finalize

/**
* TODO
*/
class ArrayDataIter() extends DataIter {
* Base class for prefetching iterators. Takes one or more DataIters
* (or any class with "reset" and "read" methods) and combine them with
* prefetching.
* @param iters list of DataIters
* @param dataNames
* @param labelNames
*/
class PrefetchingIter(val iters: List[DataIter],
val dataNames: Map[String, String] = null,
val labelNames: Map[String, String] = null) extends DataIter {
/**
* reset the iterator
*/
Expand Down
46 changes: 38 additions & 8 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,18 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
)

val mnistIter = IO.MNISTIter(params)
// test provideData
val provideData = mnistIter.provideData
val provideLabel = mnistIter.provideLabel
assert(provideData("data") === Array(100, 784))
assert(provideLabel("label") === Array(100))
// test_loop
mnistIter.reset()
val nBatch = 600
var batchCount = 0
while(mnistIter.iterNext()) {
val batch = mnistIter.next()
batchCount+=1
batchCount += 1
}
// test loop
assert(nBatch === batchCount)
Expand All @@ -49,11 +54,11 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {


/**
* not work now
* default skip this test for saving time
*/
// test("test ImageRecordIter") {
// //get data
// //"./scripts/get_cifar_data.sh" !
// // get data
// "./scripts/get_cifar_data.sh" !
//
// val params = Map(
// "path_imgrec" -> "data/cifar/train.rec",
Expand All @@ -66,11 +71,36 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
// "preprocess_threads" -> "4",
// "prefetch_buffer" -> "1"
// )
// val img_iter = IO.createIterator("ImageRecordIter", params)
// img_iter.reset()
// while(img_iter.iterNext()) {
// val batch = img_iter.next()
// val imgRecIter = IO.ImageRecordIter(params)
// val nBatch = 500
// var batchCount = 0
// // test provideData
// val provideData = imgRecIter.provideData
// val provideLabel = imgRecIter.provideLabel
// assert(provideData("data") === Array(100, 3, 28, 28))
// assert(provideLabel("label") === Array(100))
//
// imgRecIter.reset()
// while(imgRecIter.iterNext()) {
// val batch = imgRecIter.next()
// batchCount += 1
// }
// // test loop
// assert(batchCount === nBatch)
// // test reset
// imgRecIter.reset()
// imgRecIter.iterNext()
// val label0 = imgRecIter.getLabel().head.toArray
// val data0 = imgRecIter.getData().head.toArray
// imgRecIter.iterNext()
// imgRecIter.iterNext()
// imgRecIter.iterNext()
// imgRecIter.reset()
// imgRecIter.iterNext()
// val label1 = imgRecIter.getLabel().head.toArray
// val data1 = imgRecIter.getData().head.toArray
// assert(label0 === label1)
// assert(data0 === data1)
// }

// test("test NDarryIter") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,8 +648,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxDataIterGetIndex
jmethodID listAppend = env->GetMethodID(listClass,
"$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ListBuffer;");

//long class
jclass longCls = env->FindClass("java/lang/Long");
jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");

for(int i=0; i<coutSize; i++) {
env->CallObjectMethod(outIndex, listAppend, (jlong)coutIndex[i]);
env->CallObjectMethod(outIndex, listAppend,
env->NewObject(longCls, longConst, coutIndex[i]));
}
return ret;
}
Expand Down

0 comments on commit 9a60d20

Please sign in to comment.