From 397a96ba9681a6fb597baa9492e812eb29860afd Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 31 Dec 2015 00:01:30 +0800 Subject: [PATCH 1/3] Format Scala codes according to scalastyle-config.xml (mostly follows Apache Spark) --- scala-package/core/scalastyle-config.xml | 145 ++++++++++++++ .../src/main/resources/scalastyle_config.xml | 9 - .../main/scala/ml/dmlc/mxnet/EvalMetric.scala | 3 +- .../main/scala/ml/dmlc/mxnet/Executor.scala | 10 +- .../src/main/scala/ml/dmlc/mxnet/IO.scala | 189 +++++++++--------- .../scala/ml/dmlc/mxnet/Initializer.scala | 20 +- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 4 +- .../src/main/scala/ml/dmlc/mxnet/Model.scala | 6 +- .../main/scala/ml/dmlc/mxnet/Monitor.scala | 4 +- .../main/scala/ml/dmlc/mxnet/NDArray.scala | 52 +++-- .../src/main/scala/ml/dmlc/mxnet/Random.scala | 21 +- .../test/scala/ml/dmlc/mxnet/IOSuite.scala | 8 +- .../scala/ml/dmlc/mxnet/KVStoreSuite.scala | 2 + scala-package/pom.xml | 4 +- 14 files changed, 321 insertions(+), 156 deletions(-) create mode 100644 scala-package/core/scalastyle-config.xml delete mode 100644 scala-package/core/src/main/resources/scalastyle_config.xml diff --git a/scala-package/core/scalastyle-config.xml b/scala-package/core/scalastyle-config.xml new file mode 100644 index 000000000000..d0ec81c558b4 --- /dev/null +++ b/scala-package/core/scalastyle-config.xml @@ -0,0 +1,145 @@ + + + + + Scalastyle standard configuration + + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + + + + ^println$ + + + + + Class\.forName + + + + + + JavaConversions + Instead of importing implicits in scala.collection.JavaConversions._, import + scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + + + + + + diff --git a/scala-package/core/src/main/resources/scalastyle_config.xml b/scala-package/core/src/main/resources/scalastyle_config.xml deleted file mode 100644 index f5b043c57cce..000000000000 --- a/scala-package/core/src/main/resources/scalastyle_config.xml +++ /dev/null @@ -1,9 +0,0 @@ - - Scalastyle standard configuration - - - - 800 - - - \ No newline at end of file diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala index 561ba19389a4..9544a6d15d34 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala @@ -50,7 +50,8 @@ class Accuracy extends EvalMetric("accuracy") { val pred: NDArray = preds.slice(i, i) val label: NDArray = labels.slice(i, i) -// require(label.shape(0) < predLabel.shape(0), "Should not have more predict labels than actual labels ") + // require(label.shape(0) < predLabel.shape(0), + // "Should not have more predict labels than actual labels ") }) } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index cafbae36362c..b7f47f8e07b7 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -16,7 +16,8 @@ object Executor { * @param batchSize The number of samples in a mini-batch. * @param workLoadList The list of work load for different devices, in the same order as ctx * @return The split slices to get a specific slice. - * @throws IllegalArgumentException If there are two many splits such that some slice can be empty. + * @throws IllegalArgumentException + * If there are two many splits such that some slice can be empty. */ private def splitInputSlice[@specialized(Int, Float, Double) V] (batchSize: Int, workLoadList: Array[V]) @@ -84,6 +85,7 @@ object Executor { * @param symbol * @see Symbol.bind : to create executor */ +// scalastyle:off finalize class Executor(val handle: ExecutorHandle, val symbol: Symbol) { var argArrays: Array[NDArray] = null protected var gradArrays: Array[NDArray] = null @@ -131,7 +133,7 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { * This parameter is only needed when bind is called * on outputs that are not a loss function. */ - def backward(outGrads: Array[NDArray]):Unit = { + def backward(outGrads: Array[NDArray]): Unit = { require(outGrads != null) val ndArrayPtrs = outGrads.map(_.handle.value) checkCall(_LIB.mxExecutorBackward(handle, outGrads.length, ndArrayPtrs)) @@ -188,7 +190,8 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { * Whether allow extra parameters that are not needed by symbol * If this is True, no error will be thrown when arg_params or aux_params * contain extra parameters that is not needed by the executor. - * @throws IllegalArgumentException If there is additional parameters in the dict but allow_extra_params=False + * @throws IllegalArgumentException + * If there is additional parameters in the dict but allow_extra_params=False */ def copyParamsFrom(argParams: Map[String, NDArray], auxParams: Map[String, NDArray], @@ -229,3 +232,4 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) { str.value } } +// scalastyle:on finalize diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index 11dcfadbbcff..d4bdfc4aceb8 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -12,19 +12,19 @@ object IO { private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule() /** - * create iterator via iterName and params - * @param iterName name of iterator; "MNISTIter" or "ImageRecordIter" - * @param params paramters for create iterator - * @return - */ + * create iterator via iterName and params + * @param iterName name of iterator; "MNISTIter" or "ImageRecordIter" + * @param params paramters for create iterator + * @return + */ def createIterator(iterName: String, params: Map[String, String]): DataIter = { - return iterCreateFuncs(iterName)(params) + iterCreateFuncs(iterName)(params) } /** - * initi all IO creator Functions - * @return - */ + * initi all IO creator Functions + * @return + */ private def _initIOModule(): Map[String, IterCreateFunc] = { val IterCreators = new ListBuffer[DataIterCreator] checkCall(_LIB.mxListDataIters(IterCreators)) @@ -41,15 +41,15 @@ object IO { val paramStr = Base.ctypes2docstring(argNames, argTypes, argDescs) val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n" logger.debug(docStr) - return (name.value, creator(handle)) + (name.value, creator(handle)) } /** - * - * @param handle - * @param params - * @return - */ + * + * @param handle + * @param params + * @return + */ private def creator(handle: DataIterCreator)( params: Map[String, String]): DataIter = { val out = new DataIterHandle @@ -62,62 +62,63 @@ object IO { /** - * class batch of data - * @param data - * @param label - * @param index - * @param pad - */ -case class DataBatch(val data: NDArray, - val label: NDArray, - val index: List[Long], - val pad: Int) + * class batch of data + * @param data + * @param label + * @param index + * @param pad + */ +case class DataBatch(data: NDArray, + label: NDArray, + index: List[Long], + pad: Int) /** - *DataIter object in mxnet. - */ + * DataIter object in mxnet. + */ abstract class DataIter (val batchSize: Int = 0) { /** - * reset the iterator - */ + * reset the iterator + */ def reset(): Unit + /** - * Iterate to next batch - * @return whether the move is successful - */ + * Iterate to next batch + * @return whether the move is successful + */ def iterNext(): Boolean /** - * get next data batch from iterator - * @return - */ + * get next data batch from iterator + * @return + */ def next(): DataBatch = { - return new DataBatch(getData(), getLabel(), getIndex(), getPad()) + new DataBatch(getData(), getLabel(), getIndex(), getPad()) } /** - * get data of current batch - * @return the data of current batch - */ + * get data of current batch + * @return the data of current batch + */ def getData(): NDArray /** - * Get label of current batch - * @return the label of current batch - */ + * Get label of current batch + * @return the label of current batch + */ def getLabel(): NDArray /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ def getPad(): Int /** - * the index of current batch - * @return - */ + * the index of current batch + * @return + */ def getIndex(): List[Long] } @@ -126,54 +127,55 @@ abstract class DataIter (val batchSize: Int = 0) { * DataIter built in MXNet. * @param handle the handle to the underlying C++ Data Iterator */ +// scalastyle:off finalize class MXDataIter(val handle: DataIterHandle) extends DataIter { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) - override def finalize() = { + override def finalize(): Unit = { checkCall(_LIB.mxDataIterFree(handle)) } /** - * reset the iterator - */ + * reset the iterator + */ override def reset(): Unit = { checkCall(_LIB.mxDataIterBeforeFirst(handle)) } /** - * Iterate to next batch - * @return whether the move is successful - */ + * Iterate to next batch + * @return whether the move is successful + */ override def iterNext(): Boolean = { val next = new RefInt checkCall(_LIB.mxDataIterNext(handle, next)) - return next.value > 0 + next.value > 0 } /** - * get data of current batch - * @return the data of current batch - */ + * get data of current batch + * @return the data of current batch + */ override def getData(): NDArray = { val out = new NDArrayHandle checkCall(_LIB.mxDataIterGetData(handle, out)) - return new NDArray(out, writable = false) + new NDArray(out, writable = false) } /** - * Get label of current batch - * @return the label of current batch - */ + * Get label of current batch + * @return the label of current batch + */ override def getLabel(): NDArray = { val out = new NDArrayHandle checkCall(_LIB.mxDataIterGetLabel(handle, out)) - return new NDArray(out, writable = false) + new NDArray(out, writable = false) } /** - * the index of current batch - * @return - */ + * the index of current batch + * @return + */ override def getIndex(): List[Long] = { val outIndex = new ListBuffer[Long] val outSize = new RefLong @@ -182,56 +184,57 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter { } /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ override def getPad(): MXUint = { val out = new MXUintRef checkCall(_LIB.mxDataIterGetPadNum(handle, out)) - return out.value + out.value } } +// scalastyle:on finalize /** - * To do - */ + * TODO + */ class ArrayDataIter() extends DataIter { /** - * reset the iterator - */ + * reset the iterator + */ override def reset(): Unit = ??? /** - * get data of current batch - * @return the data of current batch - */ + * get data of current batch + * @return the data of current batch + */ override def getData(): NDArray = ??? /** - * Get label of current batch - * @return the label of current batch - */ + * Get label of current batch + * @return the label of current batch + */ override def getLabel(): NDArray = ??? /** - * the index of current batch - * @return - */ + * the index of current batch + * @return + */ override def getIndex(): List[Long] = ??? /** - * Iterate to next batch - * @return whether the move is successful - */ + * Iterate to next batch + * @return whether the move is successful + */ override def iterNext(): Boolean = ??? /** - * get the number of padding examples - * in current batch - * @return number of padding examples in current batch - */ - override def getPad(): MXUint = ??? + * get the number of padding examples + * in current batch + * @return number of padding examples in current batch + */ + override def getPad(): Int = ??? } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala index 5287200065e2..e1952236db94 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala @@ -1,6 +1,6 @@ package ml.dmlc.mxnet -import ml.dmlc.mxnet.NDArray.{array, zeros, ones} +import ml.dmlc.mxnet.NDArray.array /** @@ -80,9 +80,9 @@ abstract class Initializer { * * @param scale The scale of uniform distribution */ -class Uniform(protected val scale: Float=0.07f) extends Initializer { +class Uniform(protected val scale: Float = 0.07f) extends Initializer { override def _initWeight(name: String, arr: NDArray): Unit = { - Random.uniform(-scale, scale, out=arr) + Random.uniform(-scale, scale, out = arr) } } @@ -92,9 +92,9 @@ class Uniform(protected val scale: Float=0.07f) extends Initializer { * * @param sigma Standard deviation for gaussian distribution. */ -class Normal(protected val sigma: Float=0.01f) extends Initializer { +class Normal(protected val sigma: Float = 0.01f) extends Initializer { override def _initWeight(name: String, arr: NDArray): Unit = { - Random.normal(0, sigma, out=arr) + Random.normal(0, sigma, out = arr) } } @@ -106,8 +106,8 @@ class Normal(protected val sigma: Float=0.01f) extends Initializer { * @param factorType Options are: "avg", "in", "out" * @param magnitude scale of random number range */ -class Xavier(protected val rndType: String ="uniform", - protected val factorType: String ="avg", +class Xavier(protected val rndType: String = "uniform", + protected val factorType: String = "avg", protected val magnitude: Int = 3) extends Initializer { override def _initWeight(name: String, arr: NDArray): Unit = { @@ -125,9 +125,9 @@ class Xavier(protected val rndType: String ="uniform", val scale = math.sqrt(magnitude / factor).toFloat rndType match { - case "uniform" => Random.uniform(-scale, scale, out=arr) - case "normal" => Random.normal(0, scale, out=arr) + case "uniform" => Random.uniform(-scale, scale, out = arr) + case "normal" => Random.normal(0, scale, out = arr) case _ => throw new IllegalArgumentException("Unknown random type") } } -} \ No newline at end of file +} diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index cdc69d857c8f..79e84b598c15 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -85,7 +85,7 @@ class LibInfo { @native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int @native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int - //DataIter Funcs + // DataIter Funcs @native def mxListDataIters(handles: ListBuffer[DataIterCreator]): Int @native def mxDataIterCreateIter(handle: DataIterCreator, keys: Array[String], @@ -109,7 +109,7 @@ class LibInfo { outSize: RefLong): Int @native def mxDataIterGetPadNum(handle: DataIterHandle, out: MXUintRef): Int - //Executors + // Executors @native def mxExecutorOutputs(handle: ExecutorHandle, outputs: ArrayBuffer[NDArrayHandle]): Int @native def mxExecutorFree(handle: ExecutorHandle): Int @native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index 4277bc08c654..1912832341a3 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -16,14 +16,16 @@ object Model { * @param maxSize max size of the kvstore * @return Option of created [[KVStore]] and whether or not update weight on it */ - private def createKVStore(kvStore: String, numDevice: Int, maxSize: Int): (Option[KVStore], Boolean) = { + private def createKVStore(kvStore: String, + numDevice: Int, + maxSize: Int): (Option[KVStore], Boolean) = { if (numDevice == 1 && !kvStore.contains("dist")) { // no need to use kv for single device and single machine (None, false) } else { var kvType = kvStore if (kvType == "local") { - //automatically select a proper local + // automatically select a proper local kvType = if (maxSize < 1024 * 1024 * 16) { "local_update_cpu" diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala index bc48a0514af1..5ef6c8eaec81 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala @@ -28,13 +28,13 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) => private var activated: Boolean = false private var queue = new mutable.Queue[(Int, String, NDArray)] private var step: Int = 0 - private var exes = new mutable.Queue[Executor] + private var exes = new mutable.Queue[Executor] val statHelper: MXMonitorCallback = new MXMonitorCallback { override def invoke(name: String, arr: NDArrayHandle): Unit = { // wrapper for executor callback if (activated) { - val array = new NDArray(arr, writable=false) + val array = new NDArray(arr, writable = false) val elem = (step, name, statFunc(array)) queue += elem } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 0429313b6db1..cb09451a10da 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -39,7 +39,9 @@ object NDArray { } // internal NDArray function - private[mxnet] def _unaryNDArrayFunction(funcName: String, src: NDArray, out: NDArray = null): NDArray = { + private[mxnet] def _unaryNDArrayFunction(funcName: String, + src: NDArray, + out: NDArray = null): NDArray = { var output = out val function = functions(funcName) require(function != null, s"invalid function name $funcName") @@ -112,7 +114,9 @@ object NDArray { * * @return a new empty ndarray handle */ - private def _newAllocHandle(shape: Array[Int], ctx: Context, delayAlloc: Boolean): NDArrayHandle = { + private def _newAllocHandle(shape: Array[Int], + ctx: Context, + delayAlloc: Boolean): NDArrayHandle = { val hdl = new NDArrayHandle checkCall(_LIB.mxNDArrayCreate( shape, @@ -169,7 +173,11 @@ object NDArray { } else if (nMutateVars.value == 1 && nUsedVars.value == 1 && nScalars.value == 0) { (name.value, UnaryNDArrayFunction(handle, acceptEmptyMutate)) } else { - (name.value, GenericNDArrayFunction(handle, acceptEmptyMutate, nMutateVars.value, useVarsRange, scalarRange)) + (name.value, GenericNDArrayFunction(handle, + acceptEmptyMutate, + nMutateVars.value, + useVarsRange, + scalarRange)) } } @@ -196,7 +204,7 @@ object NDArray { * * @return The created NDArray. */ - def empty(shape: Array[Int], ctx: Context=null): NDArray = { + def empty(shape: Array[Int], ctx: Context = null): NDArray = { val context = if (ctx == null) Context.defaultCtx else ctx new NDArray(handle = NDArray._newAllocHandle(shape, context, delayAlloc = false)) } @@ -213,7 +221,7 @@ object NDArray { * * @return The created NDArray. */ - def zeros(shape: Array[Int], ctx: Context=null): NDArray = { + def zeros(shape: Array[Int], ctx: Context = null): NDArray = { val arr = empty(shape, ctx) arr.set(0f) arr @@ -229,7 +237,7 @@ object NDArray { * @param ctx The context of the NDArray, default to current default context. * @return The created NDArray. */ - def ones(shape: Array[Int], ctx: Context=null): NDArray = { + def ones(shape: Array[Int], ctx: Context = null): NDArray = { val arr = empty(shape, ctx) arr.set(1f) arr @@ -269,9 +277,9 @@ object NDArray { } // TODO - def _randomUniform(low: Float, high: Float, out: NDArray) = ??? + def _randomUniform(low: Float, high: Float, out: NDArray): NDArray = ??? - def _randomGaussian(mean: Float, stdvar: Float, out: NDArray) = ??? + def _randomGaussian(mean: Float, stdvar: Float, out: NDArray): NDArray = ??? /** @@ -280,7 +288,7 @@ object NDArray { * @param ctx The context of the NDArray, default to current default context. * @return The created NDArray. */ - def array(sourceArr: Array[Float], ctx: Context=null): NDArray = ??? + def array(sourceArr: Array[Float], ctx: Context = null): NDArray = ??? /** * Load ndarray from binary file. @@ -323,8 +331,9 @@ object NDArray { * NDArray object in mxnet. * NDArray is basic ndarray/Tensor like data structure in mxnet. */ +// scalastyle:off finalize class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) { - override def finalize() = { + override def finalize(): Unit = { checkCall(_LIB.mxNDArrayFree(handle)) } @@ -377,7 +386,7 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) { */ def set(value: Float): NDArray = { require(writable, "trying to assign to a readonly NDArray") - NDArray._genericNDArrayFunction("_set_value", Array[Any](value), out=Array(this)) + NDArray._genericNDArrayFunction("_set_value", Array[Any](value), out = Array(this)) this } @@ -404,14 +413,14 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) { if (!writable) { throw new IllegalArgumentException("trying to add to a readonly NDArray") } - NDArray._binaryNDArrayFunction("_plus", this, other, out=this) + NDArray._binaryNDArrayFunction("_plus", this, other, out = this) } def +=(other: Float): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to add to a readonly NDArray") } - NDArray._genericNDArrayFunction("_plus_scalar", Array[Any](this, other), out=Array(this)) + NDArray._genericNDArrayFunction("_plus_scalar", Array[Any](this, other), out = Array(this)) this } @@ -427,14 +436,14 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) { if (!writable) { throw new IllegalArgumentException("trying to subtract from a readonly NDArray") } - NDArray._binaryNDArrayFunction("_minus", this, other, out=this) + NDArray._binaryNDArrayFunction("_minus", this, other, out = this) } def -=(other: Float): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to subtract from a readonly NDArray") } - NDArray._genericNDArrayFunction("_minus_scalar", Array[Any](this, other), out=Array(this)) + NDArray._genericNDArrayFunction("_minus_scalar", Array[Any](this, other), out = Array(this)) this } @@ -450,18 +459,18 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) { NDArray._genericNDArrayFunction("_mul_scalar", Array[Any](this, -1f))(0) } - def *=(other: NDArray) = { + def *=(other: NDArray): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to multiply to a readonly NDArray") } - NDArray._binaryNDArrayFunction("_mul", this, other, out=this) + NDArray._binaryNDArrayFunction("_mul", this, other, out = this) } - def *=(other: Float) = { + def *=(other: Float): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to multiply to a readonly NDArray") } - NDArray._genericNDArrayFunction("_mul_scalar", Array[Any](this, other), out=Array(this)) + NDArray._genericNDArrayFunction("_mul_scalar", Array[Any](this, other), out = Array(this)) this } @@ -477,14 +486,14 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) { if (!writable) { throw new IllegalArgumentException("trying to divide from a readonly NDArray") } - NDArray._binaryNDArrayFunction("_div", this, other, out=this) + NDArray._binaryNDArrayFunction("_div", this, other, out = this) } def /=(other: Float): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to divide from a readonly NDArray") } - NDArray._genericNDArrayFunction("_div_scalar", Array[Any](this, other), out=Array(this)) + NDArray._genericNDArrayFunction("_div_scalar", Array[Any](this, other), out = Array(this)) this } @@ -541,6 +550,7 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) { // Get size of current NDArray. def size: Int = shape.product } +// scalastyle:on finalize object NDArrayConversions { implicit def int2Scalar(x: Int): NDArrayConversions = new NDArrayConversions(x.toFloat) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala index 9bcfe4eb162b..b7abf9627429 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala @@ -1,7 +1,6 @@ package ml.dmlc.mxnet -import ml.dmlc.mxnet.Base._ -import ml.dmlc.mxnet.NDArray.{_randomUniform, _randomGaussian, empty} +import ml.dmlc.mxnet.NDArray.{_randomGaussian, _randomUniform, empty} /** * Random Number interface of mxnet. @@ -18,7 +17,11 @@ object Random { * @param out Output place holder * @return The result NDArray with generated result. */ - def uniform(low: Float, high: Float, shape: Array[Int]=null, ctx: Context=null, out: NDArray=null): NDArray = { + def uniform(low: Float, + high: Float, + shape: Array[Int] = null, + ctx: Context = null, + out: NDArray = null): NDArray = { var outCopy = out if (outCopy != null) { require(shape == null & ctx == null, "shape and ctx is not needed when out is specified.") @@ -26,7 +29,7 @@ object Random { require(shape != null, "shape is required when out is not specified") outCopy = empty(shape, ctx) } - return _randomUniform(low, high, outCopy) + _randomUniform(low, high, outCopy) } @@ -40,7 +43,11 @@ object Random { * @param out Output place holder * @return The result NDArray with generated result. */ - def normal(mean: Float, stdvar: Float, shape: Array[Int]=null, ctx: Context=null, out: NDArray=null): NDArray = { + def normal(mean: Float, + stdvar: Float, + shape: Array[Int] = null, + ctx: Context = null, + out: NDArray = null): NDArray = { var outCopy = out if (outCopy != null) { require(shape == null & ctx == null, "shape and ctx is not needed when out is specified.") @@ -48,7 +55,7 @@ object Random { require(shape != null, "shape is required when out is not specified") outCopy = empty(shape, ctx) } - return _randomGaussian(mean, stdvar, outCopy) + _randomGaussian(mean, stdvar, outCopy) } @@ -64,7 +71,7 @@ object Random { * This means if you set the same seed, the random number sequence * generated from GPU0 can be different from CPU. */ - def seed(seedState: Int) = { + def seed(seedState: Int): Unit = { // TODO // checkCall(_LIB.mxRandomSeed(seedState)) } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala index fc6940cec50b..2991743da008 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala @@ -6,7 +6,7 @@ import scala.sys.process._ class IOSuite extends FunSuite with BeforeAndAfterAll { test("test MNISTIter") { - //get data + // get data "./scripts/get_mnist_data.sh" ! val params = Map( @@ -21,7 +21,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { ) val mnistIter = IO.createIterator("MNISTIter", params) - //test_loop + // test_loop mnistIter.reset() val nBatch = 600 var batchCount = 0 @@ -29,9 +29,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { val batch = mnistIter.next() batchCount+=1 } - //test loop + // test loop assert(nBatch === batchCount) - //test reset + // test reset mnistIter.reset() mnistIter.iterNext() val label0 = mnistIter.getLabel().toArray diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala index b237500281c8..145ba4d0b71b 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala @@ -28,7 +28,9 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll { val kv = KVStore.create() val updater = new MXKVStoreUpdater { override def update(key: Int, input: NDArray, stored: NDArray, handle: AnyRef): Unit = { + // scalastyle:off println println(s"update on key $key") + // scalastyle:on println stored += input * 2 } } diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 25c6b9cdcf1b..524c9a52eb0f 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -142,8 +142,8 @@ false ${basedir}/src/main/scala ${basedir}/src/test/scala - ${basedir}/src/main/resources/scalastyle_config.xml - ${project.basedir}/scalastyle_output.xml + ${basedir}/scalastyle-config.xml + ${basedir}/target/scalastyle-output.xml UTF-8 From def4e7078b688fc20c154f3749e8e0822da485f0 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 31 Dec 2015 11:55:04 +0800 Subject: [PATCH 2/3] scala-lint in travis --- .travis.yml | 1 + tests/travis/run_test.sh | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 23377bf0f562..4337223de846 100644 --- a/.travis.yml +++ b/.travis.yml @@ -19,6 +19,7 @@ env: #- TASK=python_test #- TASK=r_test - TASK=scala_test + - TASK=scala_lint # TODO, R test, distributed test, clang, more g++ versions diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index 43ed4cc922a0..0b747d9fc151 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -111,7 +111,6 @@ if [ ${TASK} == "python_test" ]; then exit 0 fi - if [ ${TASK} == "scala_test" ]; then if [ ${TRAVIS_OS_NAME} == "osx" ]; then LIB_GOMP_PATH=`find /usr/local/lib -name libgomp.dylib | grep -v i386 | head -n1` @@ -141,3 +140,9 @@ if [ ${TASK} == "scala_test" ]; then exit 0 fi + +if [ ${TASK} == "scala_lint" ]; then + cd scala-package/core + mvn scalastyle:check || exit -1 + exit 0 +fi From ab87ca9cfeca2747e54271378f8149173daf6db9 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Thu, 31 Dec 2015 13:41:56 +0800 Subject: [PATCH 3/3] scalastyle add return check (warning level) --- scala-package/core/scalastyle-config.xml | 3 +++ scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala | 4 ++-- tests/travis/run_test.sh | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/scala-package/core/scalastyle-config.xml b/scala-package/core/scalastyle-config.xml index d0ec81c558b4..847bbc2babe9 100644 --- a/scala-package/core/scalastyle-config.xml +++ b/scala-package/core/scalastyle-config.xml @@ -142,4 +142,7 @@ You can also disable only one rule, by specifying its rule id, as specified in: + + + diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala index d4bdfc4aceb8..9b999e0d6b32 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala @@ -56,7 +56,7 @@ object IO { val keys = params.keys.toArray val vals = params.values.toArray checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out)) - return new MXDataIter(out) + new MXDataIter(out) } } @@ -180,7 +180,7 @@ class MXDataIter(val handle: DataIterHandle) extends DataIter { val outIndex = new ListBuffer[Long] val outSize = new RefLong checkCall(_LIB.mxDataIterGetIndex(handle, outIndex, outSize)) - return outIndex.toList + outIndex.toList } /** diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index 0b747d9fc151..cd91c0db89bc 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -142,6 +142,7 @@ if [ ${TASK} == "scala_test" ]; then fi if [ ${TASK} == "scala_lint" ]; then + export JAVA_HOME=$(/usr/libexec/java_home) cd scala-package/core mvn scalastyle:check || exit -1 exit 0