forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request apache#16 from yanqingmen/scala
mxnet io package
- Loading branch information
Showing
10 changed files
with
554 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -97,4 +97,4 @@ scala-package/*/*/target/ | |
*.iml | ||
*.classpath | ||
*.project | ||
*.settings | ||
*.settings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
data_path="./data" | ||
if [ ! -d "$data_path" ]; then | ||
mkdir -p "$data_path" | ||
fi | ||
|
||
mnist_data_path="./data/mnist.zip" | ||
if [ ! -f "$mnist_data_path" ]; then | ||
wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip -P $data_path | ||
cd $data_path | ||
unzip -u mnist.zip | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
237 changes: 237 additions & 0 deletions
237
scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
package ml.dmlc.mxnet | ||
|
||
import ml.dmlc.mxnet.Base._ | ||
import org.slf4j.LoggerFactory | ||
|
||
import scala.collection.mutable.ListBuffer | ||
|
||
object IO { | ||
type IterCreateFunc = (Map[String, String]) => DataIter | ||
|
||
private val logger = LoggerFactory.getLogger(classOf[DataIter]) | ||
private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule() | ||
|
||
/** | ||
* create iterator via iterName and params | ||
* @param iterName name of iterator; "MNISTIter" or "ImageRecordIter" | ||
* @param params paramters for create iterator | ||
* @return | ||
*/ | ||
def createIterator(iterName: String, params: Map[String, String]): DataIter = { | ||
return iterCreateFuncs(iterName)(params) | ||
} | ||
|
||
/** | ||
* initi all IO creator Functions | ||
* @return | ||
*/ | ||
private def _initIOModule(): Map[String, IterCreateFunc] = { | ||
val IterCreators = new ListBuffer[DataIterCreator] | ||
checkCall(_LIB.mxListDataIters(IterCreators)) | ||
IterCreators.map(_makeIOIterator).toMap | ||
} | ||
|
||
private def _makeIOIterator(handle: DataIterCreator): (String, IterCreateFunc) = { | ||
val name = new RefString | ||
val desc = new RefString | ||
val argNames = new ListBuffer[String] | ||
val argTypes = new ListBuffer[String] | ||
val argDescs = new ListBuffer[String] | ||
checkCall(_LIB.mxDataIterGetIterInfo(handle, name, desc, argNames, argTypes, argDescs)) | ||
val paramStr = Base.ctypes2docstring(argNames, argTypes, argDescs) | ||
val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n" | ||
logger.debug(docStr) | ||
return (name.value, creator(handle)) | ||
} | ||
|
||
/** | ||
* | ||
* @param handle | ||
* @param params | ||
* @return | ||
*/ | ||
private def creator(handle: DataIterCreator)( | ||
params: Map[String, String]): DataIter = { | ||
val out = new DataIterHandle | ||
val keys = params.keys.toArray | ||
val vals = params.values.toArray | ||
checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out)) | ||
return new MXDataIter(out) | ||
} | ||
} | ||
|
||
|
||
/** | ||
* class batch of data | ||
* @param data | ||
* @param label | ||
* @param index | ||
* @param pad | ||
*/ | ||
case class DataBatch(val data: NDArray, | ||
val label: NDArray, | ||
val index: List[Long], | ||
val pad: Int) | ||
|
||
/** | ||
*DataIter object in mxnet. | ||
*/ | ||
abstract class DataIter (val batchSize: Int = 0) { | ||
/** | ||
* reset the iterator | ||
*/ | ||
def reset(): Unit | ||
/** | ||
* Iterate to next batch | ||
* @return whether the move is successful | ||
*/ | ||
def iterNext(): Boolean | ||
|
||
/** | ||
* get next data batch from iterator | ||
* @return | ||
*/ | ||
def next(): DataBatch = { | ||
return new DataBatch(getData(), getLabel(), getIndex(), getPad()) | ||
} | ||
|
||
/** | ||
* get data of current batch | ||
* @return the data of current batch | ||
*/ | ||
def getData(): NDArray | ||
|
||
/** | ||
* Get label of current batch | ||
* @return the label of current batch | ||
*/ | ||
def getLabel(): NDArray | ||
|
||
/** | ||
* get the number of padding examples | ||
* in current batch | ||
* @return number of padding examples in current batch | ||
*/ | ||
def getPad(): Int | ||
|
||
/** | ||
* the index of current batch | ||
* @return | ||
*/ | ||
def getIndex(): List[Long] | ||
|
||
} | ||
|
||
/** | ||
* DataIter built in MXNet. | ||
* @param handle the handle to the underlying C++ Data Iterator | ||
*/ | ||
class MXDataIter(val handle: DataIterHandle) extends DataIter { | ||
private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) | ||
|
||
override def finalize() = { | ||
checkCall(_LIB.mxDataIterFree(handle)) | ||
} | ||
|
||
/** | ||
* reset the iterator | ||
*/ | ||
override def reset(): Unit = { | ||
checkCall(_LIB.mxDataIterBeforeFirst(handle)) | ||
} | ||
|
||
/** | ||
* Iterate to next batch | ||
* @return whether the move is successful | ||
*/ | ||
override def iterNext(): Boolean = { | ||
val next = new RefInt | ||
checkCall(_LIB.mxDataIterNext(handle, next)) | ||
return next.value > 0 | ||
} | ||
|
||
/** | ||
* get data of current batch | ||
* @return the data of current batch | ||
*/ | ||
override def getData(): NDArray = { | ||
val out = new NDArrayHandle | ||
checkCall(_LIB.mxDataIterGetData(handle, out)) | ||
return new NDArray(out, writable = false) | ||
} | ||
|
||
/** | ||
* Get label of current batch | ||
* @return the label of current batch | ||
*/ | ||
override def getLabel(): NDArray = { | ||
val out = new NDArrayHandle | ||
checkCall(_LIB.mxDataIterGetLabel(handle, out)) | ||
return new NDArray(out, writable = false) | ||
} | ||
|
||
/** | ||
* the index of current batch | ||
* @return | ||
*/ | ||
override def getIndex(): List[Long] = { | ||
val outIndex = new ListBuffer[Long] | ||
val outSize = new RefLong | ||
checkCall(_LIB.mxDataIterGetIndex(handle, outIndex, outSize)) | ||
return outIndex.toList | ||
} | ||
|
||
/** | ||
* get the number of padding examples | ||
* in current batch | ||
* @return number of padding examples in current batch | ||
*/ | ||
override def getPad(): MXUint = { | ||
val out = new MXUintRef | ||
checkCall(_LIB.mxDataIterGetPadNum(handle, out)) | ||
return out.value | ||
} | ||
} | ||
|
||
/** | ||
* To do | ||
*/ | ||
class ArrayDataIter() extends DataIter { | ||
/** | ||
* reset the iterator | ||
*/ | ||
override def reset(): Unit = ??? | ||
|
||
/** | ||
* get data of current batch | ||
* @return the data of current batch | ||
*/ | ||
override def getData(): NDArray = ??? | ||
|
||
/** | ||
* Get label of current batch | ||
* @return the label of current batch | ||
*/ | ||
override def getLabel(): NDArray = ??? | ||
|
||
/** | ||
* the index of current batch | ||
* @return | ||
*/ | ||
override def getIndex(): List[Long] = ??? | ||
|
||
/** | ||
* Iterate to next batch | ||
* @return whether the move is successful | ||
*/ | ||
override def iterNext(): Boolean = ??? | ||
|
||
/** | ||
* get the number of padding examples | ||
* in current batch | ||
* @return number of padding examples in current batch | ||
*/ | ||
override def getPad(): MXUint = ??? | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package ml.dmlc.mxnet | ||
|
||
import org.scalatest.{BeforeAndAfterAll, FunSuite} | ||
import scala.sys.process._ | ||
|
||
|
||
class IOSuite extends FunSuite with BeforeAndAfterAll { | ||
test("test MNISTIter") { | ||
//get data | ||
"./scripts/get_mnist_data.sh" ! | ||
|
||
val params = Map( | ||
"image" -> "data/train-images-idx3-ubyte", | ||
"label" -> "data/train-labels-idx1-ubyte", | ||
"data_shape" -> "(784,)", | ||
"batch_size" -> "100", | ||
"shuffle" -> "1", | ||
"flat" -> "1", | ||
"silent" -> "0", | ||
"seed" -> "10" | ||
) | ||
|
||
val mnistIter = IO.createIterator("MNISTIter", params) | ||
//test_loop | ||
mnistIter.reset() | ||
val nBatch = 600 | ||
var batchCount = 0 | ||
while(mnistIter.iterNext()) { | ||
val batch = mnistIter.next() | ||
batchCount+=1 | ||
} | ||
//test loop | ||
assert(nBatch === batchCount) | ||
//test reset | ||
mnistIter.reset() | ||
mnistIter.iterNext() | ||
val label0 = mnistIter.getLabel().toArray | ||
mnistIter.iterNext() | ||
mnistIter.iterNext() | ||
mnistIter.iterNext() | ||
mnistIter.reset() | ||
mnistIter.iterNext() | ||
val label1 = mnistIter.getLabel().toArray | ||
assert(label0 === label1) | ||
} | ||
|
||
|
||
/** | ||
* not work now | ||
*/ | ||
// test("test ImageRecordIter") { | ||
// //get data | ||
// //"./scripts/get_cifar_data.sh" ! | ||
// | ||
// val params = Map( | ||
// "path_imgrec" -> "data/cifar/train.rec", | ||
// "mean_img" -> "data/cifar/cifar10_mean.bin", | ||
// "rand_crop" -> "False", | ||
// "and_mirror" -> "False", | ||
// "shuffle" -> "False", | ||
// "data_shape" -> "(3,28,28)", | ||
// "batch_size" -> "100", | ||
// "preprocess_threads" -> "4", | ||
// "prefetch_buffer" -> "1" | ||
// ) | ||
// val img_iter = IO.createIterator("ImageRecordIter", params) | ||
// img_iter.reset() | ||
// while(img_iter.iterNext()) { | ||
// val batch = img_iter.next() | ||
// } | ||
// } | ||
|
||
// test("test NDarryIter") { | ||
// | ||
// } | ||
} |
Oops, something went wrong.