Skip to content

Commit

Permalink
Merge pull request apache#16 from yanqingmen/scala
Browse files Browse the repository at this point in the history
mxnet io package
  • Loading branch information
yzhliu committed Dec 29, 2015
2 parents aea7c13 + d91c39c commit e74d226
Show file tree
Hide file tree
Showing 10 changed files with 554 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,4 @@ scala-package/*/*/target/
*.iml
*.classpath
*.project
*.settings
*.settings
11 changes: 11 additions & 0 deletions scala-package/core/scripts/get_mnist_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/sh
# Download and unpack the MNIST data set into ./data (relative to the
# current working directory). The download is skipped when
# ./data/mnist.zip is already present.
data_path="./data"
mkdir -p "$data_path" || exit 1

mnist_data_path="$data_path/mnist.zip"
if [ ! -f "$mnist_data_path" ]; then
  # Drop a partially downloaded archive on failure; otherwise the existence
  # check above would skip the download on every later run.
  if ! wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip -P "$data_path"; then
    rm -f "$mnist_data_path"
    exit 1
  fi
  # Unpack next to the archive; the subshell leaves the working directory alone.
  (cd "$data_path" && unzip -u mnist.zip) || exit 1
fi
3 changes: 3 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ object Base {
type MXFloatRef = RefFloat
type NDArrayHandle = RefLong
type FunctionHandle = RefLong
type DataIterHandle = RefLong
type DataIterCreator = RefLong
type KVStoreHandle = RefLong
type ExecutorHandle = RefLong


System.loadLibrary("mxnet-scala")
val _LIB = new LibInfo

Expand Down
237 changes: 237 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import org.slf4j.LoggerFactory

import scala.collection.mutable.ListBuffer

object IO {
  type IterCreateFunc = (Map[String, String]) => DataIter

  private val logger = LoggerFactory.getLogger(classOf[DataIter])
  // Must stay below `logger`: building the registry logs each iterator's description.
  private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule()

  /**
   * Create a data iterator by its registered name.
   * @param iterName name of the iterator, e.g. "MNISTIter" or "ImageRecordIter"
   * @param params parameters for the iterator, as string key/value pairs
   * @return the newly created iterator
   * @throws NoSuchElementException if no iterator is registered under `iterName`
   */
  def createIterator(iterName: String, params: Map[String, String]): DataIter = {
    val create = iterCreateFuncs.getOrElse(iterName,
      throw new NoSuchElementException(
        s"Unknown data iterator '$iterName'; available iterators: " +
          iterCreateFuncs.keys.mkString(", ")))
    create(params)
  }

  /**
   * Build the registry of iterator creation functions from the creator
   * handles listed by the native library.
   * @return map from iterator name to its creation function
   */
  private def _initIOModule(): Map[String, IterCreateFunc] = {
    val iterCreators = new ListBuffer[DataIterCreator]
    checkCall(_LIB.mxListDataIters(iterCreators))
    iterCreators.map(_makeIOIterator).toMap
  }

  /**
   * Query name and argument information of one iterator creator, log it at
   * debug level, and pair the name with a creation function bound to the handle.
   * @param handle creator handle obtained from the native library
   * @return (iterator name, creation function)
   */
  private def _makeIOIterator(handle: DataIterCreator): (String, IterCreateFunc) = {
    val name = new RefString
    val desc = new RefString
    val argNames = new ListBuffer[String]
    val argTypes = new ListBuffer[String]
    val argDescs = new ListBuffer[String]
    checkCall(_LIB.mxDataIterGetIterInfo(handle, name, desc, argNames, argTypes, argDescs))
    val paramStr = Base.ctypes2docstring(argNames, argTypes, argDescs)
    val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n"
    logger.debug(docStr)
    (name.value, creator(handle))
  }

  /**
   * Create an iterator through the given creator handle.
   * @param handle creator handle of the iterator type
   * @param params parameters for the iterator, as string key/value pairs
   * @return an MXDataIter wrapping the handle filled in by the native call
   */
  private def creator(handle: DataIterCreator)(
      params: Map[String, String]): DataIter = {
    val out = new DataIterHandle
    // A single traversal keeps keys(i) and vals(i) aligned by construction.
    val (keys, vals) = params.toArray.unzip
    checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out))
    new MXDataIter(out)
  }
}


/**
 * One batch as assembled by DataIter.next().
 * @param data data of the batch
 * @param label label of the batch
 * @param index index of the batch
 * @param pad number of padding examples in the batch
 */
case class DataBatch(data: NDArray, label: NDArray, index: List[Long], pad: Int)

/**
 * Base class of the data iterators in mxnet.
 * @param batchSize batch size of the iterator; defaults to 0
 */
abstract class DataIter(val batchSize: Int = 0) {
  /**
   * Reset the iterator.
   */
  def reset(): Unit

  /**
   * Move to the next batch.
   * @return whether the move was successful
   */
  def iterNext(): Boolean

  /**
   * Bundle data, label, index and pad of the current batch.
   * @return the current batch as a DataBatch
   */
  def next(): DataBatch = DataBatch(getData(), getLabel(), getIndex(), getPad())

  /**
   * @return the data of the current batch
   */
  def getData(): NDArray

  /**
   * @return the label of the current batch
   */
  def getLabel(): NDArray

  /**
   * @return the number of padding examples in the current batch
   */
  def getPad(): Int

  /**
   * @return the index of the current batch
   */
  def getIndex(): List[Long]
}

/**
 * DataIter backed by an iterator of the native MXNet library.
 * @param handle the handle to the underlying C++ Data Iterator
 */
class MXDataIter(val handle: DataIterHandle) extends DataIter {
  private val logger = LoggerFactory.getLogger(classOf[MXDataIter])

  /**
   * Free the native iterator when this object is finalized.
   */
  override def finalize(): Unit = checkCall(_LIB.mxDataIterFree(handle))

  /**
   * Reset the iterator.
   */
  override def reset(): Unit = checkCall(_LIB.mxDataIterBeforeFirst(handle))

  /**
   * Move to the next batch.
   * @return whether the move was successful
   */
  override def iterNext(): Boolean = {
    val moved = new RefInt
    checkCall(_LIB.mxDataIterNext(handle, moved))
    moved.value > 0
  }

  /**
   * @return the data of the current batch, wrapped as a non-writable NDArray
   */
  override def getData(): NDArray = {
    val dataHandle = new NDArrayHandle
    checkCall(_LIB.mxDataIterGetData(handle, dataHandle))
    new NDArray(dataHandle, writable = false)
  }

  /**
   * @return the label of the current batch, wrapped as a non-writable NDArray
   */
  override def getLabel(): NDArray = {
    val labelHandle = new NDArrayHandle
    checkCall(_LIB.mxDataIterGetLabel(handle, labelHandle))
    new NDArray(labelHandle, writable = false)
  }

  /**
   * @return the index of the current batch
   */
  override def getIndex(): List[Long] = {
    val indices = new ListBuffer[Long]
    // The native call requires a size reference; only the buffer is read here.
    val indexCount = new RefLong
    checkCall(_LIB.mxDataIterGetIndex(handle, indices, indexCount))
    indices.toList
  }

  /**
   * @return the number of padding examples in the current batch
   */
  override def getPad(): MXUint = {
    val padNum = new MXUintRef
    checkCall(_LIB.mxDataIterGetPadNum(handle, padNum))
    padNum.value
  }
}

/**
 * TODO: placeholder, not implemented yet.
 * Every member is defined as `???` and throws NotImplementedError when called.
 */
class ArrayDataIter() extends DataIter {
  /**
   * Reset the iterator.
   */
  override def reset(): Unit = ???

  /**
   * Move to the next batch.
   * @return whether the move was successful
   */
  override def iterNext(): Boolean = ???

  /**
   * @return the data of the current batch
   */
  override def getData(): NDArray = ???

  /**
   * @return the label of the current batch
   */
  override def getLabel(): NDArray = ???

  /**
   * @return the number of padding examples in the current batch
   */
  override def getPad(): MXUint = ???

  /**
   * @return the index of the current batch
   */
  override def getIndex(): List[Long] = ???
}


26 changes: 26 additions & 0 deletions scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,32 @@ class LibInfo {
@native def mxKVStoreBarrier(handle: KVStoreHandle): Int
@native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int
@native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int

// DataIter functions.
// Each returns an Int that the callers in IO.scala pass to checkCall.

// Lists the available iterator creators; IO.scala passes an empty buffer and
// reads the creator handles from it afterwards.
@native def mxListDataIters(handles: ListBuffer[DataIterCreator]): Int
// Creates an iterator from a creator handle; keys/vals are the parameter names
// and values, `out` is wrapped by IO.scala into an MXDataIter afterwards.
@native def mxDataIterCreateIter(handle: DataIterCreator,
keys: Array[String],
vals: Array[String],
out: DataIterHandle): Int
// Fetches the name, description and argument information of an iterator creator.
@native def mxDataIterGetIterInfo(creator: DataIterCreator,
name: RefString,
description: RefString,
argNames: ListBuffer[String],
argTypeInfos: ListBuffer[String],
argDescriptions: ListBuffer[String]): Int
// Frees an iterator handle (called from MXDataIter.finalize).
@native def mxDataIterFree(handle: DataIterHandle): Int
// Used by MXDataIter.reset().
@native def mxDataIterBeforeFirst(handle: DataIterHandle): Int
// Moves to the next batch; MXDataIter treats out.value > 0 as a successful move.
@native def mxDataIterNext(handle: DataIterHandle, out: RefInt): Int
// Fills `out` with the handle of the current batch's label.
@native def mxDataIterGetLabel(handle: DataIterHandle,
out: NDArrayHandle): Int
// Fills `out` with the handle of the current batch's data.
@native def mxDataIterGetData(handle: DataIterHandle,
out: NDArrayHandle): Int
// Fills `outIndex` with the index of the current batch.
// NOTE(review): `outSize` presumably receives the number of entries; MXDataIter
// does not read it - confirm against the JNI implementation.
@native def mxDataIterGetIndex(handle: DataIterHandle,
outIndex: ListBuffer[Long],
outSize: RefLong): Int
// Fills `out` with the number of padding examples in the current batch.
@native def mxDataIterGetPadNum(handle: DataIterHandle,
out: MXUintRef): Int
//Executors
@native def mxExecutorOutputs(handle: ExecutorHandle, outputs: ArrayBuffer[NDArrayHandle]): Int
@native def mxExecutorFree(handle: ExecutorHandle): Int
@native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int
Expand Down
76 changes: 76 additions & 0 deletions scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package ml.dmlc.mxnet

import org.scalatest.{BeforeAndAfterAll, FunSuite}
import scala.sys.process._


/**
 * Tests for the iterators created through IO.createIterator.
 */
class IOSuite extends FunSuite with BeforeAndAfterAll {
  test("test MNISTIter") {
    // Fetch the data set. Method-call syntax avoids the postfix-operator
    // language feature, and checking the exit code reports a failed download
    // here instead of as an obscure failure inside the iterator.
    val exitCode = "./scripts/get_mnist_data.sh".!
    assert(exitCode === 0)

    val params = Map(
      "image" -> "data/train-images-idx3-ubyte",
      "label" -> "data/train-labels-idx1-ubyte",
      "data_shape" -> "(784,)",
      "batch_size" -> "100",
      "shuffle" -> "1",
      "flat" -> "1",
      "silent" -> "0",
      "seed" -> "10"
    )

    val mnistIter = IO.createIterator("MNISTIter", params)

    // Test the loop: count the batches of one full pass.
    mnistIter.reset()
    val nBatch = 600
    var batchCount = 0
    while (mnistIter.iterNext()) {
      // The batch itself is not inspected; next() is called to exercise it.
      mnistIter.next()
      batchCount += 1
    }
    assert(nBatch === batchCount)

    // Test reset: the first batch after a reset must carry the same labels.
    mnistIter.reset()
    mnistIter.iterNext()
    val label0 = mnistIter.getLabel().toArray
    mnistIter.iterNext()
    mnistIter.iterNext()
    mnistIter.iterNext()
    mnistIter.reset()
    mnistIter.iterNext()
    val label1 = mnistIter.getLabel().toArray
    assert(label0 === label1)
  }


  /**
   * Does not work yet.
   */
  // test("test ImageRecordIter") {
  //   //get data
  //   //"./scripts/get_cifar_data.sh" !
  //
  //   val params = Map(
  //     "path_imgrec" -> "data/cifar/train.rec",
  //     "mean_img" -> "data/cifar/cifar10_mean.bin",
  //     "rand_crop" -> "False",
  //     "and_mirror" -> "False",
  //     "shuffle" -> "False",
  //     "data_shape" -> "(3,28,28)",
  //     "batch_size" -> "100",
  //     "preprocess_threads" -> "4",
  //     "prefetch_buffer" -> "1"
  //   )
  //   val img_iter = IO.createIterator("ImageRecordIter", params)
  //   img_iter.reset()
  //   while(img_iter.iterNext()) {
  //     val batch = img_iter.next()
  //   }
  // }

  // test("test NDArrayIter") {
  //
  // }
}
Loading

0 comments on commit e74d226

Please sign in to comment.