This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-689] add DataDesc type for the Scala Package #11844
Merged
Merged
Changes from all commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
21aaea8
add dataDesc
lanking520 0ca6088
Add amend
lanking520 88a0043
add changes with dataLayout and labelLayout
lanking520 b4fada3
add depreciate and example changes
lanking520 a2b6713
Gan and Customop fixes
lanking520 ad3f73c
change the DType
lanking520 559ed96
add one more class to convert Strings to DTypes
lanking520 33acc7b
convert layout to global
lanking520 ff34f96
scala style fix
lanking520 1252d1f
Revert to 8c7d1f8
lanking520 a0609d0
fix coding style issue
lanking520 75b518c
print full stacktraces
lanking520 ed92e73
apply changes to new constructor
lanking520 83b5826
add databatch bcc
lanking520 0e170cf
introduce undefined field
lanking520 67fed73
Fix crashes when change provideData to provideDataDesc
lanking520 46754f0
change spacing and revert test
lanking520 8c7fac1
apply DataDesc on DataBatch
lanking520 70aa7f4
unit test for NDArrayIter and MXDataiter
lanking520 e8d1a40
apply changes on CR
lanking520 44bb97e
change NDArrayIter and revert the rest
lanking520 09db459
revert change on examples
lanking520 3f862e4
apply final changes
lanking520 ad28059
remove the provideLabelShape
lanking520 93329a1
add TODO about the findings
lanking520 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory | |
import scala.annotation.varargs | ||
import scala.collection.immutable.ListMap | ||
import scala.collection.mutable.ListBuffer | ||
|
||
/** | ||
* IO iterators for loading training & validation data | ||
*/ | ||
|
@@ -110,18 +109,22 @@ object IO { | |
} | ||
|
||
// Convert data into canonical form. | ||
private[mxnet] def initData(data: IndexedSeq[NDArray], | ||
allowEmpty: Boolean, | ||
defaultName: String): IndexedSeq[(String, NDArray)] = { | ||
private[mxnet] def initDataDesc(data: IndexedSeq[NDArray], | ||
allowEmpty: Boolean, | ||
defaultName: String, | ||
defaultDType: DType, | ||
defaultLayout: String): IndexedSeq[(DataDesc, NDArray)] = { | ||
require(data != null) | ||
require(data != IndexedSeq.empty || allowEmpty) | ||
if (data == IndexedSeq.empty) { | ||
IndexedSeq() | ||
} else if (data.length == 1) { | ||
IndexedSeq((defaultName, data(0))) | ||
IndexedSeq((new DataDesc(defaultName, data(0).shape, | ||
defaultDType, defaultLayout), data(0))) | ||
} else { | ||
data.zipWithIndex.map(item => { | ||
(defaultName + "_" + item._2, item._1) | ||
(new DataDesc(defaultName + "_" + item._2, item._1.shape, | ||
defaultDType, defaultLayout), item._1) | ||
}).toIndexedSeq | ||
} | ||
} | ||
|
@@ -136,11 +139,28 @@ class DataBatch(val data: IndexedSeq[NDArray], | |
val pad: Int, | ||
// the key for the bucket that should be used for this batch, | ||
// for bucketing io only | ||
val bucketKey: AnyRef = null, | ||
// use ListMap to indicate the order of data/label loading | ||
val bucketKey: AnyRef, | ||
// use DataDesc to indicate the order of data/label loading | ||
// (must match the order of input data/label) | ||
private val providedData: ListMap[String, Shape] = null, | ||
private val providedLabel: ListMap[String, Shape] = null) { | ||
private val providedDataDesc: IndexedSeq[DataDesc], | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't you want to do the same as Andrew suggested to keep the |
||
private val providedLabelDesc: IndexedSeq[DataDesc]) { | ||
// TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)] | ||
// However, since data and label are publicly accessible (no getter and setter), | ||
// changing their type would break backward compatibility (BC). | ||
def this(data: IndexedSeq[NDArray], | ||
label: IndexedSeq[NDArray], | ||
index: IndexedSeq[Long], | ||
pad: Int, | ||
// the key for the bucket that should be used for this batch, | ||
// for bucketing io only | ||
bucketKey: AnyRef = null, | ||
// use ListMap to indicate the order of data/label loading | ||
// (must match the order of input data/label) | ||
providedData: ListMap[String, Shape] = null, | ||
providedLabel: ListMap[String, Shape] = null) { | ||
this(data, label, index, pad, bucketKey, | ||
DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel)) | ||
} | ||
/** | ||
* Dispose its data and labels | ||
* The object shall never be used after it is disposed. | ||
|
@@ -155,10 +175,29 @@ class DataBatch(val data: IndexedSeq[NDArray], | |
} | ||
|
||
// The name and shape of data | ||
def provideData: ListMap[String, Shape] = providedData | ||
def provideData: ListMap[String, Shape] = { | ||
var temp = ListMap[String, Shape]() | ||
if (providedDataDesc == null) null | ||
else { | ||
providedDataDesc.foreach(ele => temp = temp + (ele.name -> ele.shape)) | ||
temp | ||
} | ||
} | ||
|
||
// The name and shape of label | ||
def provideLabel: ListMap[String, Shape] = providedLabel | ||
def provideLabel: ListMap[String, Shape] = { | ||
var temp = ListMap[String, Shape]() | ||
if (providedLabelDesc == null) null | ||
else { | ||
providedLabelDesc.foreach(ele => temp = temp + (ele.name -> ele.shape)) | ||
temp | ||
} | ||
} | ||
|
||
def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc | ||
|
||
def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc | ||
|
||
} | ||
|
||
object DataBatch { | ||
|
@@ -171,8 +210,8 @@ object DataBatch { | |
private var index: IndexedSeq[Long] = null | ||
private var pad: Int = 0 | ||
private var bucketKey: AnyRef = null | ||
private var datatShapes: ListMap[String, Shape] = null | ||
private var labelShapes: ListMap[String, Shape] = null | ||
private var dataDesc: IndexedSeq[DataDesc] = null | ||
private var labelDesc: IndexedSeq[DataDesc] = null | ||
|
||
/** | ||
* Set the input data. | ||
|
@@ -228,37 +267,27 @@ object DataBatch { | |
|
||
/** | ||
* Provide the shape of a data. | ||
* @param name data name. | ||
* @param shape data shape. | ||
* @param dataDesc DataDescriptor | ||
* @return this. | ||
*/ | ||
def provideDataShape(name: String, shape: Shape): Builder = { | ||
if (datatShapes == null) { | ||
datatShapes = ListMap((name, shape)) | ||
} else { | ||
datatShapes = datatShapes.updated(name, shape) | ||
} | ||
def provideDataDesc(dataDesc: IndexedSeq[DataDesc]): Builder = { | ||
this.dataDesc = dataDesc | ||
this | ||
} | ||
|
||
/** | ||
* Provide the shape of a label. | ||
* @param name label name. | ||
* @param shape label shape. | ||
* @param labelDesc LabelDescriptor | ||
* @return this. | ||
*/ | ||
def provideLabelShape(name: String, shape: Shape): Builder = { | ||
if (labelShapes == null) { | ||
labelShapes = ListMap((name, shape)) | ||
} else { | ||
labelShapes = labelShapes.updated(name, shape) | ||
} | ||
def provideLabelDesc(labelDesc: IndexedSeq[DataDesc]): Builder = { | ||
this.labelDesc = labelDesc | ||
this | ||
} | ||
|
||
def build(): DataBatch = { | ||
require(data != null, "data is required.") | ||
new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes) | ||
new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc) | ||
} | ||
} | ||
} | ||
|
@@ -280,7 +309,8 @@ abstract class DataIter extends Iterator[DataBatch] { | |
*/ | ||
@throws(classOf[NoSuchElementException]) | ||
def next(): DataBatch = { | ||
new DataBatch(getData(), getLabel(), getIndex(), getPad()) | ||
new DataBatch(getData(), getLabel(), getIndex(), getPad(), | ||
null, null, null) | ||
} | ||
|
||
/** | ||
|
@@ -309,11 +339,19 @@ abstract class DataIter extends Iterator[DataBatch] { | |
def getIndex(): IndexedSeq[Long] | ||
|
||
// The name and shape of data provided by this iterator | ||
@deprecated | ||
def provideData: ListMap[String, Shape] | ||
|
||
// The name and shape of label provided by this iterator | ||
@deprecated | ||
def provideLabel: ListMap[String, Shape] | ||
|
||
// Provide type:DataDesc of the data | ||
def provideDataDesc: IndexedSeq[DataDesc] | ||
|
||
// Provide type:DataDesc of the label | ||
def provideLabelDesc: IndexedSeq[DataDesc] | ||
|
||
// For bucketing io only | ||
// The bucket key for the default symbol. | ||
def defaultBucketKey: AnyRef = null | ||
|
@@ -332,8 +370,9 @@ abstract class DataPack() extends Iterable[DataBatch] { | |
|
||
// A named data descriptor: holds the name, shape, dtype and layout of a data/label entry. | ||
case class DataDesc(name: String, shape: Shape, | ||
dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") { | ||
require(shape.length == layout.length, ("number of dimensions in shape :%d with" + | ||
dtype: DType = DType.Float32, layout: String = Layout.UNDEFINED) { | ||
require(layout == Layout.UNDEFINED || shape.length == layout.length, | ||
("number of dimensions in shape :%d with" + | ||
" shape: %s should match the length of the layout: %d with layout: %s"). | ||
format(shape.length, shape.toString, layout.length, layout)) | ||
|
||
|
@@ -343,6 +382,8 @@ case class DataDesc(name: String, shape: Shape, | |
} | ||
|
||
object DataDesc { | ||
|
||
private val logger = LoggerFactory.getLogger(classOf[DataDesc]) | ||
/** | ||
* Get the dimension that corresponds to the batch size. | ||
* @param layout layout string. For example, "NCHW". | ||
|
@@ -352,9 +393,19 @@ object DataDesc { | |
* for each data-parallelism device. | ||
*/ | ||
def getBatchAxis(layout: Option[String]): Int = { | ||
layout.map(_.indexOf('N')).getOrElse(0) | ||
if (layout.isEmpty|| layout.get == Layout.UNDEFINED) { | ||
logger.warn("Found Undefined Layout, will use default index 0 for batch axis") | ||
0 | ||
} else { | ||
if (layout.get.contains('N')) { | ||
layout.get.indexOf("N") | ||
} else { | ||
throw new IllegalArgumentException("no Batch Axis('N') found in Layout!") | ||
} | ||
} | ||
} | ||
|
||
@deprecated | ||
implicit def ListMap2Descs(shapes: ListMap[String, Shape]): IndexedSeq[DataDesc] = { | ||
if (shapes != null) { | ||
shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq | ||
|
35 changes: 35 additions & 0 deletions
35
scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mxnet | ||
|
||
/** | ||
* Layout definition of DataDesc | ||
* N Batch size | ||
* C channels | ||
* H Height | ||
* W Width | ||
* T sequence length | ||
* __undefined__ default value of Layout | ||
*/ | ||
object Layout { | ||
val UNDEFINED = "__undefined__" | ||
val NCHW = "NCHW" | ||
val NTC = "NTC" | ||
val NT = "NT" | ||
val N = "N" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
otherwise throw IllegalArgumentException
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in the recent commit