Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-689] add DataDesc type for the Scala Package #11844

Merged
merged 25 commits into from
Aug 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,15 @@ object DType extends Enumeration {
case DType.Unknown => 0
}
}
private[mxnet] def getType(dtypeStr: String): DType = {
dtypeStr match {
case "UInt8" => DType.UInt8
case "Int32" => DType.Int32
case "Float16" => DType.Float16
case "Float32" => DType.Float32
case "Float64" => DType.Float64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

otherwise throw IllegalArgumentException

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in the recent commit

case _ => throw new IllegalArgumentException(
s"DType: $dtypeStr not found! please set it in DType.scala")
}
}
}
121 changes: 86 additions & 35 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory
import scala.annotation.varargs
import scala.collection.immutable.ListMap
import scala.collection.mutable.ListBuffer

/**
* IO iterators for loading training & validation data
*/
Expand Down Expand Up @@ -110,18 +109,22 @@ object IO {
}

// Convert data into canonical form.
private[mxnet] def initData(data: IndexedSeq[NDArray],
allowEmpty: Boolean,
defaultName: String): IndexedSeq[(String, NDArray)] = {
/**
 * Pair every input array with a DataDesc describing it.
 *
 * A single array is named exactly `defaultName`; several arrays are named
 * `defaultName_0`, `defaultName_1`, ... following their position in `data`.
 * Each descriptor takes its shape from the array itself and its dtype/layout
 * from the supplied defaults.
 *
 * @param data          arrays to describe; must not be null
 * @param allowEmpty    whether an empty `data` is accepted
 * @param defaultName   base name used for the generated descriptors
 * @param defaultDType  dtype recorded in every generated descriptor
 * @param defaultLayout layout recorded in every generated descriptor
 * @return (descriptor, array) pairs in the same order as `data`
 * @throws IllegalArgumentException if `data` is null, or empty while `allowEmpty` is false
 */
private[mxnet] def initDataDesc(data: IndexedSeq[NDArray],
                                allowEmpty: Boolean,
                                defaultName: String,
                                defaultDType: DType,
                                defaultLayout: String): IndexedSeq[(DataDesc, NDArray)] = {
  require(data != null, "data must not be null")
  require(data.nonEmpty || allowEmpty, "data must not be empty unless allowEmpty is set")
  if (data.isEmpty) {
    IndexedSeq()
  } else if (data.length == 1) {
    // A lone array keeps the bare default name (no index suffix).
    val only = data(0)
    IndexedSeq((new DataDesc(defaultName, only.shape, defaultDType, defaultLayout), only))
  } else {
    data.zipWithIndex.map { case (array, idx) =>
      (new DataDesc(s"${defaultName}_$idx", array.shape, defaultDType, defaultLayout), array)
    }.toIndexedSeq
  }
}
Expand All @@ -136,11 +139,28 @@ class DataBatch(val data: IndexedSeq[NDArray],
val pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
val bucketKey: AnyRef = null,
// use ListMap to indicate the order of data/label loading
val bucketKey: AnyRef,
// use DataDesc to indicate the order of data/label loading
// (must match the order of input data/label)
private val providedData: ListMap[String, Shape] = null,
private val providedLabel: ListMap[String, Shape] = null) {
private val providedDataDesc: IndexedSeq[DataDesc],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you want to do the same as Andrew suggested to keep the NDArray and DataDesc together here as well?

private val providedLabelDesc: IndexedSeq[DataDesc]) {
// TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)]
// However, since the data and label can be accessed publicly (no getter and setter)
// the change on this will break BC
/**
 * Auxiliary constructor kept for callers that describe data/label with
 * name -> shape ListMaps instead of DataDesc sequences.
 * Both maps are converted with DataDesc.ListMap2Descs and forwarded to the
 * primary constructor; all other arguments are passed through unchanged.
 * NOTE(review): the conversion of a null map is not fully visible here
 * (ListMap2Descs is truncated) - confirm it yields null rather than throwing.
 */
def this(data: IndexedSeq[NDArray],
label: IndexedSeq[NDArray],
index: IndexedSeq[Long],
pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
bucketKey: AnyRef = null,
// use ListMap to indicate the order of data/label loading
// (must match the order of input data/label)
providedData: ListMap[String, Shape] = null,
providedLabel: ListMap[String, Shape] = null) {
// Delegate to the DataDesc-based primary constructor.
this(data, label, index, pad, bucketKey,
DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel))
}
/**
* Dispose its data and labels
* The object shall never be used after it is disposed.
Expand All @@ -155,10 +175,29 @@ class DataBatch(val data: IndexedSeq[NDArray],
}

// The name and shape of data
def provideData: ListMap[String, Shape] = providedData
// Name -> shape view of the data descriptors, built by adding one entry per
// descriptor in sequence order. Returns null when no data descriptors were
// provided, so existing null checks in callers keep working.
def provideData: ListMap[String, Shape] = {
  if (providedDataDesc == null) {
    null
  } else {
    providedDataDesc.foldLeft(ListMap.empty[String, Shape]) { (shapes, desc) =>
      shapes + (desc.name -> desc.shape)
    }
  }
}

// The name and shape of label
def provideLabel: ListMap[String, Shape] = providedLabel
// Name -> shape view of the label descriptors, built by adding one entry per
// descriptor in sequence order. Returns null when no label descriptors were
// provided, so existing null checks in callers keep working.
def provideLabel: ListMap[String, Shape] = {
  if (providedLabelDesc == null) {
    null
  } else {
    providedLabelDesc.foldLeft(ListMap.empty[String, Shape]) { (shapes, desc) =>
      shapes + (desc.name -> desc.shape)
    }
  }
}

// The data descriptors exactly as passed to the constructor (may be null).
def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc

// The label descriptors exactly as passed to the constructor (may be null).
def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc

}

object DataBatch {
Expand All @@ -171,8 +210,8 @@ object DataBatch {
private var index: IndexedSeq[Long] = null
private var pad: Int = 0
private var bucketKey: AnyRef = null
private var datatShapes: ListMap[String, Shape] = null
private var labelShapes: ListMap[String, Shape] = null
private var dataDesc: IndexedSeq[DataDesc] = null
private var labelDesc: IndexedSeq[DataDesc] = null

/**
* Set the input data.
Expand Down Expand Up @@ -228,37 +267,27 @@ object DataBatch {

/**
* Provide the shape of a data.
* @param name data name.
* @param shape data shape.
* @param dataDesc DataDescriptor
* @return this.
*/
def provideDataShape(name: String, shape: Shape): Builder = {
if (datatShapes == null) {
datatShapes = ListMap((name, shape))
} else {
datatShapes = datatShapes.updated(name, shape)
}
def provideDataDesc(dataDesc: IndexedSeq[DataDesc]): Builder = {
// Replaces any previously stored data descriptors; no copy is made.
this.dataDesc = dataDesc
// Return the builder itself so calls can be chained.
this
}

/**
* Provide the shape of a label.
* @param name label name.
* @param shape label shape.
* @param labelDesc LabelDescriptor
* @return this.
*/
def provideLabelShape(name: String, shape: Shape): Builder = {
if (labelShapes == null) {
labelShapes = ListMap((name, shape))
} else {
labelShapes = labelShapes.updated(name, shape)
}
def provideLabelDesc(labelDesc: IndexedSeq[DataDesc]): Builder = {
// Replaces any previously stored label descriptors; no copy is made.
this.labelDesc = labelDesc
// Return the builder itself so calls can be chained.
this
}

/**
 * Assemble a DataBatch from the values collected by this builder.
 *
 * @return a new DataBatch carrying the builder's data, label, index, pad,
 *         bucket key and data/label descriptors
 * @throws IllegalArgumentException if no data has been set
 */
def build(): DataBatch = {
  require(data != null, "data is required.")
  // Only the descriptor-based batch is created; the batch that used to be
  // built from the old shape maps was discarded immediately.
  new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc)
}
}
}
Expand All @@ -280,7 +309,8 @@ abstract class DataIter extends Iterator[DataBatch] {
*/
@throws(classOf[NoSuchElementException])
def next(): DataBatch = {
  // Query each getter exactly once and wrap the results in a single batch.
  // No bucket key or data/label descriptors are attached at this level.
  new DataBatch(getData(), getLabel(), getIndex(), getPad(),
    null, null, null)
}

/**
Expand Down Expand Up @@ -309,11 +339,19 @@ abstract class DataIter extends Iterator[DataBatch] {
def getIndex(): IndexedSeq[Long]

// The name and shape of data provided by this iterator.
// Marked deprecated; provideDataDesc exposes the same name/shape information
// as DataDesc values, which also carry dtype and layout.
@deprecated
def provideData: ListMap[String, Shape]

// The name and shape of label provided by this iterator.
// Marked deprecated; provideLabelDesc exposes the same name/shape information
// as DataDesc values, which also carry dtype and layout.
@deprecated
def provideLabel: ListMap[String, Shape]

// Descriptors (name, shape, dtype, layout) of the data provided by this iterator
def provideDataDesc: IndexedSeq[DataDesc]

// Descriptors (name, shape, dtype, layout) of the label provided by this iterator
def provideLabelDesc: IndexedSeq[DataDesc]

// For bucketing io only
// The bucket key for the default symbol.
def defaultBucketKey: AnyRef = null
Expand All @@ -332,8 +370,9 @@ abstract class DataPack() extends Iterable[DataBatch] {

// Named data desc description contains name, shape, type and other extended attributes.
case class DataDesc(name: String, shape: Shape,
dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") {
require(shape.length == layout.length, ("number of dimensions in shape :%d with" +
dtype: DType = DType.Float32, layout: String = Layout.UNDEFINED) {
require(layout == Layout.UNDEFINED || shape.length == layout.length,
("number of dimensions in shape :%d with" +
" shape: %s should match the length of the layout: %d with layout: %s").
format(shape.length, shape.toString, layout.length, layout))

Expand All @@ -343,6 +382,8 @@ case class DataDesc(name: String, shape: Shape,
}

object DataDesc {

private val logger = LoggerFactory.getLogger(classOf[DataDesc])
/**
* Get the dimension that corresponds to the batch size.
* @param layout layout string. For example, "NCHW".
Expand All @@ -352,9 +393,19 @@ object DataDesc {
* for each data-parallelism device.
*/
def getBatchAxis(layout: Option[String]): Int = layout match {
  // A missing layout and the UNDEFINED placeholder are treated alike:
  // warn and fall back to axis 0.
  case None | Some(Layout.UNDEFINED) =>
    logger.warn("Found Undefined Layout, will use default index 0 for batch axis")
    0
  case Some(definedLayout) =>
    // One scan both checks that 'N' is present and yields its position.
    val batchAxis = definedLayout.indexOf('N')
    if (batchAxis < 0) {
      throw new IllegalArgumentException("no Batch Axis('N') found in Layout!")
    }
    batchAxis
}

@deprecated
implicit def ListMap2Descs(shapes: ListMap[String, Shape]): IndexedSeq[DataDesc] = {
if (shapes != null) {
shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq
Expand Down
35 changes: 35 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet

/**
 * Layout definition of DataDesc
 * N Batch size
 * C channels
 * H Height
 * W Width
 * T sequence length
 * __undefined__ default value of Layout
 */
object Layout {
  // Placeholder used when no layout has been specified.
  val UNDEFINED: String = "__undefined__"
  val NCHW: String = "NCHW"
  val NTC: String = "NTC"
  val NT: String = "NT"
  val N: String = "N"
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ import java.io.ByteArrayInputStream

/**
* Scala interface for read/write RecordIO data format
*
* @author Depeng Liang
*
* @param uri, path to recordIO file.
* @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
*/
Expand Down Expand Up @@ -144,7 +141,7 @@ object MXRecordIO {
*
* @author Depeng Liang
*
* @param idx_path, path to index file
* @param idxPath, path to index file
* @param uri, path to recordIO file.
* @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
* @param keyType, data type for keys.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
package org.apache.mxnet.io

import org.apache.mxnet.Base._
import org.apache.mxnet.{DataBatch, DataIter, DataPack, NDArray, Shape, WarnIfNotDisposed}
import org.apache.mxnet.DType.DType
import org.apache.mxnet._
import org.apache.mxnet.IO._
import org.slf4j.LoggerFactory

Expand All @@ -41,21 +42,31 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
// fix me if any better way found)
private var currentBatch: DataBatch = null

private val (_provideData: ListMap[String, Shape],
private val (_provideDataDesc: IndexedSeq[DataDesc],
_provideLabelDesc: IndexedSeq[DataDesc],
_provideData: ListMap[String, Shape],
_provideLabel: ListMap[String, Shape],
_batchSize: Int) =
_batchSize: Int) = {
if (hasNext) {
iterNext()
val data = currentBatch.data(0)
val label = currentBatch.label(0)
// properties
val res = (ListMap(dataName -> data.shape), ListMap(labelName -> label.shape), data.shape(0))
val res = (
// TODO: need to allow user to specify DType and Layout
IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, Layout.UNDEFINED)),
IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, Layout.UNDEFINED)),
ListMap(dataName -> data.shape),
ListMap(labelName -> label.shape),
data.shape(0))
currentBatch.dispose()
reset()
res
} else {
(null, null, 0)
(null, null, null, null, 0)
}
}


private var disposed = false
protected def isDisposed = disposed
Expand Down Expand Up @@ -101,10 +112,12 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
// Advance the underlying iterator handle by one step and refresh currentBatch.
// Returns true when a batch was fetched (it is then held in currentBatch);
// returns false, with currentBatch cleared to null, when nothing was fetched.
private def iterNext(): Boolean = {
  val next = new RefInt
  checkCall(_LIB.mxDataIterNext(handle, next))
  val fetched = next.value > 0
  currentBatch =
    if (fetched) {
      // No bucket key or data/label descriptors are attached to the batch here.
      new DataBatch(data = getData(), label = getLabel(),
        index = getIndex(), pad = getPad(),
        null, null, null)
    } else {
      null
    }
  fetched
}
Expand Down Expand Up @@ -152,11 +165,19 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
}

// The name and shape of data provided by this iterator.
// Returns the value captured once when the iterator was initialised.
@deprecated
override def provideData: ListMap[String, Shape] = _provideData

// The name and shape of label provided by this iterator.
// Returns the value captured once when the iterator was initialised.
@deprecated
override def provideLabel: ListMap[String, Shape] = _provideLabel

// Descriptors (name, shape, dtype, layout) of the data, captured at initialisation
override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc

// Descriptors (name, shape, dtype, layout) of the label, captured at initialisation
override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc

override def hasNext: Boolean = {
if (currentBatch != null) {
true
Expand Down
Loading