
review require() usages to add meaningful messages.
mdespriee committed Sep 15, 2018
1 parent 52116d4 commit 2790d7d
Showing 26 changed files with 285 additions and 184 deletions.
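
All of the changes below follow one pattern: Scala's require(requirement: Boolean, message: => Any) throws IllegalArgumentException prefixed with "requirement failed:", and a bare require(cond) reports nothing else. Because the message parameter is by-name, the interpolated string is only built when the check actually fails. A minimal standalone sketch of the before/after (illustrative, not taken from the diff itself):

// Before: a failing call reports only "requirement failed".
def updateBefore(labels: IndexedSeq[Float], preds: IndexedSeq[Float]): Unit =
  require(labels.length == preds.length)

// After: the message carries both lengths; it costs nothing on success
// because require's message argument is evaluated lazily.
def updateAfter(labels: IndexedSeq[Float], preds: IndexedSeq[Float]): Unit =
  require(labels.length == preds.length,
    s"labels and predictions should have the same length " +
    s"(got ${labels.length} and ${preds.length}).")

// updateAfter(IndexedSeq(1f), IndexedSeq(1f, 2f)) throws
// java.lang.IllegalArgumentException: requirement failed: labels and
// predictions should have the same length (got 1 and 2).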
47 changes: 34 additions & 13 deletions scala-package/core/src/main/scala/org/apache/mxnet/EvalMetric.scala
@@ -133,25 +133,30 @@ class TopKAccuracy(topK: Int) extends EvalMetric("top_k_accuracy") {
 
   override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
     require(labels.length == preds.length,
-      "labels and predictions should have the same length.")
+      s"labels and predictions should have the same length " +
+      s"(got ${labels.length} and ${preds.length}).")
 
     for ((pred, label) <- preds zip labels) {
       val predShape = pred.shape
       val dims = predShape.length
-      require(dims <= 2, "Predictions should be no more than 2 dims.")
+      require(dims <= 2, s"Predictions should be no more than 2 dims (got $predShape).")
       val labelArray = label.toArray
       val numSamples = predShape(0)
       if (dims == 1) {
         val predArray = pred.toArray.zipWithIndex.sortBy(_._1).reverse.map(_._2)
-        require(predArray.length == labelArray.length)
+        require(predArray.length == labelArray.length,
+          s"Each label and prediction array should have the same length " +
+          s"(got ${labelArray.length} and ${predArray.length}).")
         this.sumMetric +=
           labelArray.zip(predArray).map { case (l, p) => if (l == p) 1 else 0 }.sum
       } else if (dims == 2) {
         val numclasses = predShape(1)
         val predArray = pred.toArray.grouped(numclasses).map { a =>
           a.zipWithIndex.sortBy(_._1).reverse.map(_._2)
         }.toArray
-        require(predArray.length == labelArray.length)
+        require(predArray.length == labelArray.length,
+          s"Each label and prediction array should have the same length " +
+          s"(got ${labelArray.length} and ${predArray.length}).")
         val topK = Math.max(this.topK, numclasses)
         for (j <- 0 until topK) {
           this.sumMetric +=
@@ -169,7 +174,8 @@ class TopKAccuracy(topK: Int) extends EvalMetric("top_k_accuracy") {
 class F1 extends EvalMetric("f1") {
   override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
     require(labels.length == preds.length,
-      "labels and predictions should have the same length.")
+      s"labels and predictions should have the same length " +
+      s"(got ${labels.length} and ${preds.length}).")
 
     for ((pred, label) <- preds zip labels) {
       val predLabel = NDArray.argmax_channel(pred)
@@ -223,7 +229,8 @@ class F1 extends EvalMetric("f1") {
 class Perplexity(ignoreLabel: Option[Int] = None, axis: Int = -1) extends EvalMetric("Perplexity") {
   override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
     require(labels.length == preds.length,
-      "labels and predictions should have the same length.")
+      s"labels and predictions should have the same length " +
+      s"(got ${labels.length} and ${preds.length}).")
     var loss = 0d
     var num = 0
     val probs = ArrayBuffer[NDArray]()
@@ -261,12 +268,16 @@ class Perplexity(ignoreLabel: Option[Int] = None, axis: Int = -1) extends EvalMetric("Perplexity") {
  */
 class MAE extends EvalMetric("mae") {
   override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
-    require(labels.size == preds.size, "labels and predictions should have the same length.")
+    require(labels.size == preds.size,
+      s"labels and predictions should have the same length " +
+      s"(got ${labels.length} and ${preds.length}).")
 
     for ((label, pred) <- labels zip preds) {
       val labelArr = label.toArray
       val predArr = pred.toArray
-      require(labelArr.length == predArr.length)
+      require(labelArr.length == predArr.length,
+        s"Each label and prediction array should have the same length " +
+        s"(got ${labelArr.length} and ${predArr.length}).")
       this.sumMetric +=
         (labelArr zip predArr).map { case (l, p) => Math.abs(l - p) }.sum / labelArr.length
       this.numInst += 1
@@ -277,12 +288,16 @@ class MAE extends EvalMetric("mae") {
 // Calculate Mean Squared Error loss
 class MSE extends EvalMetric("mse") {
   override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
-    require(labels.size == preds.size, "labels and predictions should have the same length.")
+    require(labels.size == preds.size,
+      s"labels and predictions should have the same length " +
+      s"(got ${labels.length} and ${preds.length}).")
 
     for ((label, pred) <- labels zip preds) {
       val labelArr = label.toArray
       val predArr = pred.toArray
-      require(labelArr.length == predArr.length)
+      require(labelArr.length == predArr.length,
+        s"Each label and prediction array should have the same length " +
+        s"(got ${labelArr.length} and ${predArr.length}).")
       this.sumMetric +=
         (labelArr zip predArr).map { case (l, p) => (l - p) * (l - p) }.sum / labelArr.length
       this.numInst += 1
@@ -295,12 +310,16 @@ class MSE extends EvalMetric("mse") {
  */
 class RMSE extends EvalMetric("rmse") {
   override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
-    require(labels.size == preds.size, "labels and predictions should have the same length.")
+    require(labels.size == preds.size,
+      s"labels and predictions should have the same length " +
+      s"(got ${labels.length} and ${preds.length}).")
 
     for ((label, pred) <- labels zip preds) {
       val labelArr = label.toArray
       val predArr = pred.toArray
-      require(labelArr.length == predArr.length)
+      require(labelArr.length == predArr.length,
+        s"Each label and prediction array should have the same length " +
+        s"(got ${labelArr.length} and ${predArr.length}).")
       val metric: Double = Math.sqrt(
         (labelArr zip predArr).map { case (l, p) => (l - p) * (l - p) }.sum / labelArr.length)
       this.sumMetric += metric.toFloat
@@ -318,7 +337,9 @@ class RMSE extends EvalMetric("rmse") {
 class CustomMetric(fEval: (NDArray, NDArray) => Float,
                    name: String) extends EvalMetric(name) {
   override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
-    require(labels.size == preds.size, "labels and predictions should have the same length.")
+    require(labels.size == preds.size,
+      s"labels and predictions should have the same length " +
+      s"(got ${labels.length} and ${preds.length}).")
 
     for ((label, pred) <- labels zip preds) {
       this.sumMetric += fEval(label, pred)
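
A hedged usage sketch for the metrics above (assumes the MXNet native library is on the library path; NDArray.ones and NDArray subtraction are taken to be the usual Scala-package operations):

import org.apache.mxnet.{CustomMetric, NDArray, Shape}

// Sum of absolute differences as a custom metric, using the constructor
// shown in the diff: CustomMetric(fEval: (NDArray, NDArray) => Float, name).
val metric = new CustomMetric(
  (label, pred) => (label - pred).toArray.map(v => Math.abs(v)).sum, "sum_abs_diff")

val labels = IndexedSeq(NDArray.ones(Shape(4)))
val preds = IndexedSeq(NDArray.ones(Shape(4)), NDArray.ones(Shape(4)))

// One label batch against two prediction batches now fails with
// "labels and predictions should have the same length (got 1 and 2)."
metric.update(labels, preds)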
scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -26,7 +26,7 @@ object Executor {
   // Get the dictionary given name and ndarray pairs.
   private[mxnet] def getDict(names: Seq[String],
                              ndarrays: Seq[NDArray]): Map[String, NDArray] = {
-    require(names.toSet.size == names.length, "Duplicate names detected")
+    require(names.toSet.size == names.length, s"Duplicate names detected in ($names)")
     (names zip ndarrays).toMap
   }
 }
@@ -86,7 +86,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
   def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
               kwargs: Map[String, Shape]): Executor = {
     val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
-    require(argShapes != null, "Insufficient argument shapes provided.")
+    // TODO: more precise error message should be provided by backend
+    require(argShapes != null, "Shape inference failed. " +
+      s"Known shapes are $kwargs for symbol arguments ${symbol.listArguments()} " +
+      s"and aux states ${symbol.listAuxiliaryStates()}")
 
     var newArgDict = Map[String, NDArray]()
     var newGradDict = Map[String, NDArray]()
@@ -194,13 +197,13 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
    * on outputs that are not a loss function.
    */
   def backward(outGrads: Array[NDArray]): Unit = {
-    require(outGrads != null)
+    require(outGrads != null, "outGrads must not be null")
     val ndArrayPtrs = outGrads.map(_.handle)
     checkCall(_LIB.mxExecutorBackward(handle, ndArrayPtrs))
   }
 
   def backward(outGrad: NDArray): Unit = {
-    require(outGrad != null)
+    require(outGrad != null, "outGrad must not be null")
     backward(Array(outGrad))
   }
 
@@ -271,15 +274,15 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
       if (argDict.contains(name)) {
         array.copyTo(argDict(name))
       } else {
-        require(allowExtraParams, s"Find name $name that is not in the arguments")
+        require(allowExtraParams, s"Provided name $name is not in the arguments")
       }
     }
     if (auxParams != null) {
       auxParams.foreach { case (name, array) =>
         if (auxDict.contains(name)) {
           array.copyTo(auxDict(name))
         } else {
-          require(allowExtraParams, s"Find name $name that is not in the auxiliary states")
+          require(allowExtraParams, s"Provided name $name is not in the auxiliary states")
         }
       }
     }
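
The null-guarded inferShape pattern above recurs in ExecutorManager and FeedForward below. A hypothetical helper (not part of this commit) that factors it out, assuming inferShape's usual (argShapes, outShapes, auxShapes) triple of IndexedSeq[Shape]:

import org.apache.mxnet.{Shape, Symbol}

// inferShape returns null shapes when inference cannot complete, so the
// guard reports everything the caller knew at that point.
def inferShapeOrFail(sym: Symbol, known: Map[String, Shape])
  : (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = {
  val (argShapes, outShapes, auxShapes) = sym.inferShape(known)
  require(argShapes != null, "Shape inference failed. " +
    s"Known shapes are $known for symbol arguments ${sym.listArguments()} " +
    s"and aux states ${sym.listAuxiliaryStates()}")
  (argShapes, outShapes, auxShapes)
}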
scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala
@@ -54,7 +54,8 @@ private[mxnet] class DataParallelExecutorManager(private val symbol: Symbol,
   if (workLoadList == null) {
     workLoadList = Seq.fill(numDevice)(1f)
   }
-  require(workLoadList.size == numDevice, "Invalid settings for work load.")
+  require(workLoadList.size == numDevice, "Invalid settings for work load. " +
+    s"Size (${workLoadList.size}) should match num devices ($numDevice)")
 
   private val slices = ExecutorManager.splitInputSlice(trainData.batchSize, workLoadList)
 
@@ -212,13 +213,13 @@ private[mxnet] object ExecutorManager {
   private[mxnet] def checkArguments(symbol: Symbol): Unit = {
     val argNames = symbol.listArguments()
     require(argNames.toSet.size == argNames.length,
-      "Find duplicated argument name," +
+      "Found duplicated argument name, " +
       "please make the weight name non-duplicated(using name arguments)," +
       s"arguments are $argNames")
 
     val auxNames = symbol.listAuxiliaryStates()
     require(auxNames.toSet.size == auxNames.length,
-      "Find duplicated auxiliary param name," +
+      "Found duplicated auxiliary param name, " +
       "please make the weight name non-duplicated(using name arguments)," +
       s"arguments are $auxNames")
   }
@@ -272,15 +273,21 @@ private[mxnet] object ExecutorManager {
                               sharedDataArrays: mutable.Map[String, NDArray] = null,
                               inputTypes: ListMap[String, DType] = null) = {
     val (argShape, _, auxShape) = sym.inferShape(inputShapes)
-    require(argShape != null)
+    // TODO: more precise error message should be provided by backend
+    require(argShape != null, "Shape inference failed. " +
+      s"Known shapes are $inputShapes for symbol arguments ${sym.listArguments()} " +
+      s"and aux states ${sym.listAuxiliaryStates()}")
+
     val inputTypesUpdate =
       if (inputTypes == null) {
         inputShapes.map { case (key, _) => (key, Base.MX_REAL_TYPE) }
       } else {
         inputTypes
       }
     val (argTypes, _, auxTypes) = sym.inferType(inputTypesUpdate)
-    require(argTypes != null)
+    require(argTypes != null, "Type inference failed. " +
+      s"Known types are $inputTypes for symbol arguments ${sym.listArguments()} " +
+      s"and aux states ${sym.listAuxiliaryStates()}")
 
     val argArrays = ArrayBuffer.empty[NDArray]
     val gradArrays: mutable.Map[String, NDArray] =
@@ -311,7 +318,8 @@ private[mxnet] object ExecutorManager {
             val arr = sharedDataArrays(name)
             if (arr.shape.product >= argShape(i).product) {
               // good, we can share this memory
-              require(argTypes(i) == arr.dtype)
+              require(argTypes(i) == arr.dtype,
+                s"Type ${arr.dtype} of argument $name does not match inferred type ${argTypes(i)}")
               arr.reshape(argShape(i))
             } else {
               DataParallelExecutorManager.logger.warn(
@@ -345,8 +353,10 @@ private[mxnet] object ExecutorManager {
           NDArray.zeros(argShape(i), ctx, dtype = argTypes(i))
         } else {
           val arr = baseExec.argDict(name)
-          require(arr.shape == argShape(i))
-          require(arr.dtype == argTypes(i))
+          require(arr.shape == argShape(i),
+            s"Shape ${arr.shape} of argument $name does not match inferred shape ${argShape(i)}")
+          require(arr.dtype == argTypes(i),
+            s"Type ${arr.dtype} of argument $name does not match inferred type ${argTypes(i)}")
           if (gradSet.contains(name)) {
             gradArrays.put(name, baseExec.gradDict(name))
           }
@@ -356,15 +366,20 @@ private[mxnet] object ExecutorManager {
         }
       }
     // create or borrow aux variables
+    val auxNames = sym.listAuxiliaryStates()
     val auxArrays =
       if (baseExec == null) {
         (auxShape zip auxTypes) map { case (s, t) =>
           NDArray.zeros(s, ctx, dtype = t)
         }
       } else {
         baseExec.auxArrays.zipWithIndex.map { case (a, i) =>
-          require(auxShape(i) == a.shape)
-          require(auxTypes(i) == a.dtype)
+          require(auxShape(i) == a.shape,
+            s"Shape ${a.shape} of aux variable ${auxNames(i)} does not match " +
+            s"inferred shape ${auxShape(i)}")
+          require(auxTypes(i) == a.dtype,
+            s"Type ${a.dtype} of aux variable ${auxNames(i)} does not match " +
+            s"inferred type ${auxTypes(i)}")
           a
         }.toSeq
       }
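
The aux-state loop above introduces the auxNames val because zipWithIndex alone cannot name the offending entry. A distilled sketch of that pairing (a hypothetical helper, not in the commit):

import org.apache.mxnet.Shape

// Pair each expected/actual shape with its name so a failure can say
// which aux variable mismatched, not just that one did.
def checkBorrowedShapes(names: IndexedSeq[String], expected: IndexedSeq[Shape],
                        actual: IndexedSeq[Shape]): Unit = {
  for (((name, e), a) <- names.zip(expected).zip(actual)) {
    require(e == a, s"Shape $a of aux variable $name does not match inferred shape $e")
  }
}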
scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -99,7 +99,7 @@ class FeedForward private(
   // verify the argument of the default symbol and user provided parameters
   def checkArguments(): Unit = {
     if (!argumentChecked) {
-      require(symbol != null)
+      require(symbol != null, "Symbol must not be null")
       // check if symbol contain duplicated names.
       ExecutorManager.checkArguments(symbol)
       // rematch parameters to delete useless ones
@@ -169,7 +169,9 @@ class FeedForward private(
   private def initPredictor(inputShapes: Map[String, Shape]): Unit = {
     if (this.predExec != null) {
       val (argShapes, _, _) = symbol.inferShape(inputShapes)
-      require(argShapes != null, "Incomplete input shapes")
+      require(argShapes != null, "Shape inference failed. " +
+        s"Known shapes are $inputShapes for symbol arguments ${symbol.listArguments()} " +
+        s"and aux states ${symbol.listAuxiliaryStates()}")
       val predShapes = this.predExec.argArrays.map(_.shape)
       if (argShapes.sameElements(predShapes)) {
         return
@@ -187,7 +189,8 @@ class FeedForward private(
     require(y != null || !isTrain, "y must be specified")
     val label = if (y == null) NDArray.zeros(X.shape(0)) else y
     require(label.shape.length == 1, "Label must be 1D")
-    require(X.shape(0) == label.shape(0), "The numbers of data points and labels not equal")
+    require(X.shape(0) == label.shape(0),
+      s"The numbers of data points (${X.shape(0)}) and labels (${label.shape(0)}) are not equal")
     if (isTrain) {
       new NDArrayIter(IndexedSeq(X), IndexedSeq(label), batchSize,
         shuffle = isTrain, lastBatchHandle = "roll_over")
@@ -402,7 +405,7 @@ class FeedForward private(
    * - ``prefix-epoch.params`` will be saved for parameters.
    */
   def save(prefix: String, epoch: Int = this.numEpoch): Unit = {
-    require(epoch >= 0)
+    require(epoch >= 0, s"epoch must be >=0 (got $epoch)")
     Model.saveCheckpoint(prefix, epoch, this.symbol, getArgParams, getAuxParams)
   }
 
9 changes: 4 additions & 5 deletions scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -114,8 +114,9 @@ object IO {
                               defaultName: String,
                               defaultDType: DType,
                               defaultLayout: String): IndexedSeq[(DataDesc, NDArray)] = {
-    require(data != null)
-    require(data != IndexedSeq.empty || allowEmpty)
+    require(data != null, "data is required.")
+    require(data != IndexedSeq.empty || allowEmpty,
+      "data should not be empty when allowEmpty is false")
     if (data == IndexedSeq.empty) {
       IndexedSeq()
     } else if (data.length == 1) {
@@ -372,9 +373,7 @@ abstract class DataPack() extends Iterable[DataBatch] {
 case class DataDesc(name: String, shape: Shape,
                     dtype: DType = DType.Float32, layout: String = Layout.UNDEFINED) {
   require(layout == Layout.UNDEFINED || shape.length == layout.length,
-    ("number of dimensions in shape :%d with" +
-      " shape: %s should match the length of the layout: %d with layout: %s").
-      format(shape.length, shape.toString, layout.length, layout))
+    s"number of dimensions in $shape should match the layout $layout")
 
   override def toString(): String = {
     s"DataDesc[$name,$shape,$dtype,$layout]"
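
A hedged failure sketch for the tightened DataDesc message (assuming Shape prints as its dimension tuple, as the interpolation above relies on):

import org.apache.mxnet.{DataDesc, Shape}

// A 4-d shape declared with a 3-character layout trips the requirement:
// IllegalArgumentException: requirement failed:
//   number of dimensions in (32,3,224,224) should match the layout NCH
val desc = DataDesc("data", Shape(32, 3, 224, 224), layout = "NCH")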
scala-package/core/src/main/scala/org/apache/mxnet/Initializer.scala
@@ -98,7 +98,8 @@ abstract class Initializer {
  */
 class Mixed(protected val patterns: List[String],
             protected val initializers: List[Initializer]) extends Initializer {
-  require(patterns.length == initializers.length)
+  require(patterns.length == initializers.length,
+    "Should provide a pattern for each initializer")
   private val map = patterns.map(_.r).zip(initializers)
 
   override def apply(name: String, arr: NDArray): Unit = {
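
A hedged failure sketch for the Mixed check (Uniform and Xavier are assumed to be the stock initializers of this package with their default-argument constructors; any two Initializer instances would do):

import org.apache.mxnet.{Mixed, Uniform, Xavier}

// One pattern for two initializers now reports "Should provide a pattern
// for each initializer" instead of a bare "requirement failed".
val init = new Mixed(
  List(".*bias"),
  List(new Uniform(0.01f), new Xavier()))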
scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
@@ -160,7 +160,8 @@ object Model {
                                  argParams: Map[String, NDArray],
                                  paramNames: IndexedSeq[String],
                                  updateOnKVStore: Boolean): Unit = {
-    require(paramArrays.length == paramNames.length)
+    require(paramArrays.length == paramNames.length,
+      "Provided parameter arrays do not match parameter names")
     for (idx <- 0 until paramArrays.length) {
       val paramOnDevs = paramArrays(idx)
       val name = paramNames(idx)