diff --git a/.travis.yml b/.travis.yml
index 23377bf0f562..4337223de846 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -19,6 +19,7 @@ env:
#- TASK=python_test
#- TASK=r_test
- TASK=scala_test
+ - TASK=scala_lint
# TODO, R test, distributed test, clang, more g++ versions
diff --git a/scala-package/core/scalastyle-config.xml b/scala-package/core/scalastyle-config.xml
new file mode 100644
index 000000000000..847bbc2babe9
--- /dev/null
+++ b/scala-package/core/scalastyle-config.xml
@@ -0,0 +1,148 @@
+
+
+
+
+ Scalastyle standard configuration
+
+
+
+
+
+
+
+
+
+
+
+
+
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW
+
+
+
+
+
+ ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW
+
+
+
+
+
+
+
+
+ ^println$
+
+
+
+
+ Class\.forName
+
+
+
+
+
+ JavaConversions
+ Instead of importing implicits in scala.collection.JavaConversions._, import
+ scala.collection.JavaConverters._ and use .asScala / .asJava methods
+
+
+
+
+
+
+
+
+
+
+
diff --git a/scala-package/core/src/main/resources/scalastyle_config.xml b/scala-package/core/src/main/resources/scalastyle_config.xml
deleted file mode 100644
index f5b043c57cce..000000000000
--- a/scala-package/core/src/main/resources/scalastyle_config.xml
+++ /dev/null
@@ -1,9 +0,0 @@
-
- Scalastyle standard configuration
-
-
-
- 800
-
-
-
\ No newline at end of file
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
index 561ba19389a4..9544a6d15d34 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala
@@ -50,7 +50,8 @@ class Accuracy extends EvalMetric("accuracy") {
val pred: NDArray = preds.slice(i, i)
val label: NDArray = labels.slice(i, i)
-// require(label.shape(0) < predLabel.shape(0), "Should not have more predict labels than actual labels ")
+ // require(label.shape(0) < predLabel.shape(0),
+ // "Should not have more predict labels than actual labels ")
})
}
}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
index cafbae36362c..b7f47f8e07b7 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
@@ -16,7 +16,8 @@ object Executor {
* @param batchSize The number of samples in a mini-batch.
* @param workLoadList The list of work load for different devices, in the same order as ctx
* @return The split slices to get a specific slice.
- * @throws IllegalArgumentException If there are two many splits such that some slice can be empty.
+ * @throws IllegalArgumentException
+ * If there are too many splits such that some slice can be empty.
*/
private def splitInputSlice[@specialized(Int, Float, Double) V]
(batchSize: Int, workLoadList: Array[V])
@@ -84,6 +85,7 @@ object Executor {
* @param symbol
* @see Symbol.bind : to create executor
*/
+// scalastyle:off finalize
class Executor(val handle: ExecutorHandle, val symbol: Symbol) {
var argArrays: Array[NDArray] = null
protected var gradArrays: Array[NDArray] = null
@@ -131,7 +133,7 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) {
* This parameter is only needed when bind is called
* on outputs that are not a loss function.
*/
- def backward(outGrads: Array[NDArray]):Unit = {
+ def backward(outGrads: Array[NDArray]): Unit = {
require(outGrads != null)
val ndArrayPtrs = outGrads.map(_.handle.value)
checkCall(_LIB.mxExecutorBackward(handle, outGrads.length, ndArrayPtrs))
@@ -188,7 +190,8 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) {
* Whether allow extra parameters that are not needed by symbol
* If this is True, no error will be thrown when arg_params or aux_params
* contain extra parameters that is not needed by the executor.
- * @throws IllegalArgumentException If there is additional parameters in the dict but allow_extra_params=False
+ * @throws IllegalArgumentException
+ * If there are additional parameters in the dict but allow_extra_params=False
*/
def copyParamsFrom(argParams: Map[String, NDArray],
auxParams: Map[String, NDArray],
@@ -229,3 +232,4 @@ class Executor(val handle: ExecutorHandle, val symbol: Symbol) {
str.value
}
}
+// scalastyle:on finalize
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
index 11dcfadbbcff..9b999e0d6b32 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/IO.scala
@@ -12,19 +12,19 @@ object IO {
private val iterCreateFuncs: Map[String, IterCreateFunc] = _initIOModule()
/**
- * create iterator via iterName and params
- * @param iterName name of iterator; "MNISTIter" or "ImageRecordIter"
- * @param params paramters for create iterator
- * @return
- */
+ * create iterator via iterName and params
+ * @param iterName name of iterator; "MNISTIter" or "ImageRecordIter"
+ * @param params parameters for creating the iterator
+ * @return
+ */
def createIterator(iterName: String, params: Map[String, String]): DataIter = {
- return iterCreateFuncs(iterName)(params)
+ iterCreateFuncs(iterName)(params)
}
/**
- * initi all IO creator Functions
- * @return
- */
+ * initialize all IO creator functions
+ * @return
+ */
private def _initIOModule(): Map[String, IterCreateFunc] = {
val IterCreators = new ListBuffer[DataIterCreator]
checkCall(_LIB.mxListDataIters(IterCreators))
@@ -41,83 +41,84 @@ object IO {
val paramStr = Base.ctypes2docstring(argNames, argTypes, argDescs)
val docStr = s"${name.value}\n${desc.value}\n\n$paramStr\n"
logger.debug(docStr)
- return (name.value, creator(handle))
+ (name.value, creator(handle))
}
/**
- *
- * @param handle
- * @param params
- * @return
- */
+ *
+ * @param handle
+ * @param params
+ * @return
+ */
private def creator(handle: DataIterCreator)(
params: Map[String, String]): DataIter = {
val out = new DataIterHandle
val keys = params.keys.toArray
val vals = params.values.toArray
checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out))
- return new MXDataIter(out)
+ new MXDataIter(out)
}
}
/**
- * class batch of data
- * @param data
- * @param label
- * @param index
- * @param pad
- */
-case class DataBatch(val data: NDArray,
- val label: NDArray,
- val index: List[Long],
- val pad: Int)
+ * class batch of data
+ * @param data
+ * @param label
+ * @param index
+ * @param pad
+ */
+case class DataBatch(data: NDArray,
+ label: NDArray,
+ index: List[Long],
+ pad: Int)
/**
- *DataIter object in mxnet.
- */
+ * DataIter object in mxnet.
+ */
abstract class DataIter (val batchSize: Int = 0) {
/**
- * reset the iterator
- */
+ * reset the iterator
+ */
def reset(): Unit
+
/**
- * Iterate to next batch
- * @return whether the move is successful
- */
+ * Iterate to next batch
+ * @return whether the move is successful
+ */
def iterNext(): Boolean
/**
- * get next data batch from iterator
- * @return
- */
+ * get next data batch from iterator
+ * @return
+ */
def next(): DataBatch = {
- return new DataBatch(getData(), getLabel(), getIndex(), getPad())
+ new DataBatch(getData(), getLabel(), getIndex(), getPad())
}
/**
- * get data of current batch
- * @return the data of current batch
- */
+ * get data of current batch
+ * @return the data of current batch
+ */
def getData(): NDArray
/**
- * Get label of current batch
- * @return the label of current batch
- */
+ * Get label of current batch
+ * @return the label of current batch
+ */
def getLabel(): NDArray
/**
- * get the number of padding examples
- * in current batch
- * @return number of padding examples in current batch
- */
+ * get the number of padding examples
+ * in current batch
+ * @return number of padding examples in current batch
+ */
def getPad(): Int
/**
- * the index of current batch
- * @return
- */
+ * the index of current batch
+ * @return
+ */
def getIndex(): List[Long]
}
@@ -126,112 +127,114 @@ abstract class DataIter (val batchSize: Int = 0) {
* DataIter built in MXNet.
* @param handle the handle to the underlying C++ Data Iterator
*/
+// scalastyle:off finalize
class MXDataIter(val handle: DataIterHandle) extends DataIter {
private val logger = LoggerFactory.getLogger(classOf[MXDataIter])
- override def finalize() = {
+ override def finalize(): Unit = {
checkCall(_LIB.mxDataIterFree(handle))
}
/**
- * reset the iterator
- */
+ * reset the iterator
+ */
override def reset(): Unit = {
checkCall(_LIB.mxDataIterBeforeFirst(handle))
}
/**
- * Iterate to next batch
- * @return whether the move is successful
- */
+ * Iterate to next batch
+ * @return whether the move is successful
+ */
override def iterNext(): Boolean = {
val next = new RefInt
checkCall(_LIB.mxDataIterNext(handle, next))
- return next.value > 0
+ next.value > 0
}
/**
- * get data of current batch
- * @return the data of current batch
- */
+ * get data of current batch
+ * @return the data of current batch
+ */
override def getData(): NDArray = {
val out = new NDArrayHandle
checkCall(_LIB.mxDataIterGetData(handle, out))
- return new NDArray(out, writable = false)
+ new NDArray(out, writable = false)
}
/**
- * Get label of current batch
- * @return the label of current batch
- */
+ * Get label of current batch
+ * @return the label of current batch
+ */
override def getLabel(): NDArray = {
val out = new NDArrayHandle
checkCall(_LIB.mxDataIterGetLabel(handle, out))
- return new NDArray(out, writable = false)
+ new NDArray(out, writable = false)
}
/**
- * the index of current batch
- * @return
- */
+ * the index of current batch
+ * @return
+ */
override def getIndex(): List[Long] = {
val outIndex = new ListBuffer[Long]
val outSize = new RefLong
checkCall(_LIB.mxDataIterGetIndex(handle, outIndex, outSize))
- return outIndex.toList
+ outIndex.toList
}
/**
- * get the number of padding examples
- * in current batch
- * @return number of padding examples in current batch
- */
+ * get the number of padding examples
+ * in current batch
+ * @return number of padding examples in current batch
+ */
override def getPad(): MXUint = {
val out = new MXUintRef
checkCall(_LIB.mxDataIterGetPadNum(handle, out))
- return out.value
+ out.value
}
}
+// scalastyle:on finalize
/**
- * To do
- */
+ * TODO
+ */
class ArrayDataIter() extends DataIter {
/**
- * reset the iterator
- */
+ * reset the iterator
+ */
override def reset(): Unit = ???
/**
- * get data of current batch
- * @return the data of current batch
- */
+ * get data of current batch
+ * @return the data of current batch
+ */
override def getData(): NDArray = ???
/**
- * Get label of current batch
- * @return the label of current batch
- */
+ * Get label of current batch
+ * @return the label of current batch
+ */
override def getLabel(): NDArray = ???
/**
- * the index of current batch
- * @return
- */
+ * the index of current batch
+ * @return
+ */
override def getIndex(): List[Long] = ???
/**
- * Iterate to next batch
- * @return whether the move is successful
- */
+ * Iterate to next batch
+ * @return whether the move is successful
+ */
override def iterNext(): Boolean = ???
/**
- * get the number of padding examples
- * in current batch
- * @return number of padding examples in current batch
- */
- override def getPad(): MXUint = ???
+ * get the number of padding examples
+ * in current batch
+ * @return number of padding examples in current batch
+ */
+ override def getPad(): Int = ???
}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala
index 5287200065e2..e1952236db94 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala
@@ -1,6 +1,6 @@
package ml.dmlc.mxnet
-import ml.dmlc.mxnet.NDArray.{array, zeros, ones}
+import ml.dmlc.mxnet.NDArray.array
/**
@@ -80,9 +80,9 @@ abstract class Initializer {
*
* @param scale The scale of uniform distribution
*/
-class Uniform(protected val scale: Float=0.07f) extends Initializer {
+class Uniform(protected val scale: Float = 0.07f) extends Initializer {
override def _initWeight(name: String, arr: NDArray): Unit = {
- Random.uniform(-scale, scale, out=arr)
+ Random.uniform(-scale, scale, out = arr)
}
}
@@ -92,9 +92,9 @@ class Uniform(protected val scale: Float=0.07f) extends Initializer {
*
* @param sigma Standard deviation for gaussian distribution.
*/
-class Normal(protected val sigma: Float=0.01f) extends Initializer {
+class Normal(protected val sigma: Float = 0.01f) extends Initializer {
override def _initWeight(name: String, arr: NDArray): Unit = {
- Random.normal(0, sigma, out=arr)
+ Random.normal(0, sigma, out = arr)
}
}
@@ -106,8 +106,8 @@ class Normal(protected val sigma: Float=0.01f) extends Initializer {
* @param factorType Options are: "avg", "in", "out"
* @param magnitude scale of random number range
*/
-class Xavier(protected val rndType: String ="uniform",
- protected val factorType: String ="avg",
+class Xavier(protected val rndType: String = "uniform",
+ protected val factorType: String = "avg",
protected val magnitude: Int = 3) extends Initializer {
override def _initWeight(name: String, arr: NDArray): Unit = {
@@ -125,9 +125,9 @@ class Xavier(protected val rndType: String ="uniform",
val scale = math.sqrt(magnitude / factor).toFloat
rndType match {
- case "uniform" => Random.uniform(-scale, scale, out=arr)
- case "normal" => Random.normal(0, scale, out=arr)
+ case "uniform" => Random.uniform(-scale, scale, out = arr)
+ case "normal" => Random.normal(0, scale, out = arr)
case _ => throw new IllegalArgumentException("Unknown random type")
}
}
-}
\ No newline at end of file
+}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
index cdc69d857c8f..79e84b598c15 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
@@ -85,7 +85,7 @@ class LibInfo {
@native def mxKVStoreGetGroupSize(handle: KVStoreHandle, size: RefInt): Int
@native def mxKVStoreGetRank(handle: KVStoreHandle, size: RefInt): Int
- //DataIter Funcs
+ // DataIter Funcs
@native def mxListDataIters(handles: ListBuffer[DataIterCreator]): Int
@native def mxDataIterCreateIter(handle: DataIterCreator,
keys: Array[String],
@@ -109,7 +109,7 @@ class LibInfo {
outSize: RefLong): Int
@native def mxDataIterGetPadNum(handle: DataIterHandle,
out: MXUintRef): Int
- //Executors
+ // Executors
@native def mxExecutorOutputs(handle: ExecutorHandle, outputs: ArrayBuffer[NDArrayHandle]): Int
@native def mxExecutorFree(handle: ExecutorHandle): Int
@native def mxExecutorForward(handle: ExecutorHandle, isTrain: Int): Int
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
index 4277bc08c654..1912832341a3 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
@@ -16,14 +16,16 @@ object Model {
* @param maxSize max size of the kvstore
* @return Option of created [[KVStore]] and whether or not update weight on it
*/
- private def createKVStore(kvStore: String, numDevice: Int, maxSize: Int): (Option[KVStore], Boolean) = {
+ private def createKVStore(kvStore: String,
+ numDevice: Int,
+ maxSize: Int): (Option[KVStore], Boolean) = {
if (numDevice == 1 && !kvStore.contains("dist")) {
// no need to use kv for single device and single machine
(None, false)
} else {
var kvType = kvStore
if (kvType == "local") {
- //automatically select a proper local
+ // automatically select a proper local
kvType =
if (maxSize < 1024 * 1024 * 16) {
"local_update_cpu"
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala
index bc48a0514af1..5ef6c8eaec81 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Monitor.scala
@@ -28,13 +28,13 @@ class Monitor(protected val interval: Int, protected var statFunc: (NDArray) =>
private var activated: Boolean = false
private var queue = new mutable.Queue[(Int, String, NDArray)]
private var step: Int = 0
- private var exes = new mutable.Queue[Executor]
+ private var exes = new mutable.Queue[Executor]
val statHelper: MXMonitorCallback = new MXMonitorCallback {
override def invoke(name: String, arr: NDArrayHandle): Unit = {
// wrapper for executor callback
if (activated) {
- val array = new NDArray(arr, writable=false)
+ val array = new NDArray(arr, writable = false)
val elem = (step, name, statFunc(array))
queue += elem
}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
index 0429313b6db1..cb09451a10da 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -39,7 +39,9 @@ object NDArray {
}
// internal NDArray function
- private[mxnet] def _unaryNDArrayFunction(funcName: String, src: NDArray, out: NDArray = null): NDArray = {
+ private[mxnet] def _unaryNDArrayFunction(funcName: String,
+ src: NDArray,
+ out: NDArray = null): NDArray = {
var output = out
val function = functions(funcName)
require(function != null, s"invalid function name $funcName")
@@ -112,7 +114,9 @@ object NDArray {
*
* @return a new empty ndarray handle
*/
- private def _newAllocHandle(shape: Array[Int], ctx: Context, delayAlloc: Boolean): NDArrayHandle = {
+ private def _newAllocHandle(shape: Array[Int],
+ ctx: Context,
+ delayAlloc: Boolean): NDArrayHandle = {
val hdl = new NDArrayHandle
checkCall(_LIB.mxNDArrayCreate(
shape,
@@ -169,7 +173,11 @@ object NDArray {
} else if (nMutateVars.value == 1 && nUsedVars.value == 1 && nScalars.value == 0) {
(name.value, UnaryNDArrayFunction(handle, acceptEmptyMutate))
} else {
- (name.value, GenericNDArrayFunction(handle, acceptEmptyMutate, nMutateVars.value, useVarsRange, scalarRange))
+ (name.value, GenericNDArrayFunction(handle,
+ acceptEmptyMutate,
+ nMutateVars.value,
+ useVarsRange,
+ scalarRange))
}
}
@@ -196,7 +204,7 @@ object NDArray {
*
* @return The created NDArray.
*/
- def empty(shape: Array[Int], ctx: Context=null): NDArray = {
+ def empty(shape: Array[Int], ctx: Context = null): NDArray = {
val context = if (ctx == null) Context.defaultCtx else ctx
new NDArray(handle = NDArray._newAllocHandle(shape, context, delayAlloc = false))
}
@@ -213,7 +221,7 @@ object NDArray {
*
* @return The created NDArray.
*/
- def zeros(shape: Array[Int], ctx: Context=null): NDArray = {
+ def zeros(shape: Array[Int], ctx: Context = null): NDArray = {
val arr = empty(shape, ctx)
arr.set(0f)
arr
@@ -229,7 +237,7 @@ object NDArray {
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
- def ones(shape: Array[Int], ctx: Context=null): NDArray = {
+ def ones(shape: Array[Int], ctx: Context = null): NDArray = {
val arr = empty(shape, ctx)
arr.set(1f)
arr
@@ -269,9 +277,9 @@ object NDArray {
}
// TODO
- def _randomUniform(low: Float, high: Float, out: NDArray) = ???
+ def _randomUniform(low: Float, high: Float, out: NDArray): NDArray = ???
- def _randomGaussian(mean: Float, stdvar: Float, out: NDArray) = ???
+ def _randomGaussian(mean: Float, stdvar: Float, out: NDArray): NDArray = ???
/**
@@ -280,7 +288,7 @@ object NDArray {
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
- def array(sourceArr: Array[Float], ctx: Context=null): NDArray = ???
+ def array(sourceArr: Array[Float], ctx: Context = null): NDArray = ???
/**
* Load ndarray from binary file.
@@ -323,8 +331,9 @@ object NDArray {
* NDArray object in mxnet.
* NDArray is basic ndarray/Tensor like data structure in mxnet.
*/
+// scalastyle:off finalize
class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
- override def finalize() = {
+ override def finalize(): Unit = {
checkCall(_LIB.mxNDArrayFree(handle))
}
@@ -377,7 +386,7 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
*/
def set(value: Float): NDArray = {
require(writable, "trying to assign to a readonly NDArray")
- NDArray._genericNDArrayFunction("_set_value", Array[Any](value), out=Array(this))
+ NDArray._genericNDArrayFunction("_set_value", Array[Any](value), out = Array(this))
this
}
@@ -404,14 +413,14 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
if (!writable) {
throw new IllegalArgumentException("trying to add to a readonly NDArray")
}
- NDArray._binaryNDArrayFunction("_plus", this, other, out=this)
+ NDArray._binaryNDArrayFunction("_plus", this, other, out = this)
}
def +=(other: Float): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to add to a readonly NDArray")
}
- NDArray._genericNDArrayFunction("_plus_scalar", Array[Any](this, other), out=Array(this))
+ NDArray._genericNDArrayFunction("_plus_scalar", Array[Any](this, other), out = Array(this))
this
}
@@ -427,14 +436,14 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
if (!writable) {
throw new IllegalArgumentException("trying to subtract from a readonly NDArray")
}
- NDArray._binaryNDArrayFunction("_minus", this, other, out=this)
+ NDArray._binaryNDArrayFunction("_minus", this, other, out = this)
}
def -=(other: Float): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to subtract from a readonly NDArray")
}
- NDArray._genericNDArrayFunction("_minus_scalar", Array[Any](this, other), out=Array(this))
+ NDArray._genericNDArrayFunction("_minus_scalar", Array[Any](this, other), out = Array(this))
this
}
@@ -450,18 +459,18 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
NDArray._genericNDArrayFunction("_mul_scalar", Array[Any](this, -1f))(0)
}
- def *=(other: NDArray) = {
+ def *=(other: NDArray): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to multiply to a readonly NDArray")
}
- NDArray._binaryNDArrayFunction("_mul", this, other, out=this)
+ NDArray._binaryNDArrayFunction("_mul", this, other, out = this)
}
- def *=(other: Float) = {
+ def *=(other: Float): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to multiply to a readonly NDArray")
}
- NDArray._genericNDArrayFunction("_mul_scalar", Array[Any](this, other), out=Array(this))
+ NDArray._genericNDArrayFunction("_mul_scalar", Array[Any](this, other), out = Array(this))
this
}
@@ -477,14 +486,14 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
if (!writable) {
throw new IllegalArgumentException("trying to divide from a readonly NDArray")
}
- NDArray._binaryNDArrayFunction("_div", this, other, out=this)
+ NDArray._binaryNDArrayFunction("_div", this, other, out = this)
}
def /=(other: Float): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to divide from a readonly NDArray")
}
- NDArray._genericNDArrayFunction("_div_scalar", Array[Any](this, other), out=Array(this))
+ NDArray._genericNDArrayFunction("_div_scalar", Array[Any](this, other), out = Array(this))
this
}
@@ -541,6 +550,7 @@ class NDArray(val handle: NDArrayHandle, val writable: Boolean = true) {
// Get size of current NDArray.
def size: Int = shape.product
}
+// scalastyle:on finalize
object NDArrayConversions {
implicit def int2Scalar(x: Int): NDArrayConversions = new NDArrayConversions(x.toFloat)
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
index 9bcfe4eb162b..b7abf9627429 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
@@ -1,7 +1,6 @@
package ml.dmlc.mxnet
-import ml.dmlc.mxnet.Base._
-import ml.dmlc.mxnet.NDArray.{_randomUniform, _randomGaussian, empty}
+import ml.dmlc.mxnet.NDArray.{_randomGaussian, _randomUniform, empty}
/**
* Random Number interface of mxnet.
@@ -18,7 +17,11 @@ object Random {
* @param out Output place holder
* @return The result NDArray with generated result.
*/
- def uniform(low: Float, high: Float, shape: Array[Int]=null, ctx: Context=null, out: NDArray=null): NDArray = {
+ def uniform(low: Float,
+ high: Float,
+ shape: Array[Int] = null,
+ ctx: Context = null,
+ out: NDArray = null): NDArray = {
var outCopy = out
if (outCopy != null) {
require(shape == null & ctx == null, "shape and ctx is not needed when out is specified.")
@@ -26,7 +29,7 @@ object Random {
require(shape != null, "shape is required when out is not specified")
outCopy = empty(shape, ctx)
}
- return _randomUniform(low, high, outCopy)
+ _randomUniform(low, high, outCopy)
}
@@ -40,7 +43,11 @@ object Random {
* @param out Output place holder
* @return The result NDArray with generated result.
*/
- def normal(mean: Float, stdvar: Float, shape: Array[Int]=null, ctx: Context=null, out: NDArray=null): NDArray = {
+ def normal(mean: Float,
+ stdvar: Float,
+ shape: Array[Int] = null,
+ ctx: Context = null,
+ out: NDArray = null): NDArray = {
var outCopy = out
if (outCopy != null) {
require(shape == null & ctx == null, "shape and ctx is not needed when out is specified.")
@@ -48,7 +55,7 @@ object Random {
require(shape != null, "shape is required when out is not specified")
outCopy = empty(shape, ctx)
}
- return _randomGaussian(mean, stdvar, outCopy)
+ _randomGaussian(mean, stdvar, outCopy)
}
@@ -64,7 +71,7 @@ object Random {
* This means if you set the same seed, the random number sequence
* generated from GPU0 can be different from CPU.
*/
- def seed(seedState: Int) = {
+ def seed(seedState: Int): Unit = {
// TODO
// checkCall(_LIB.mxRandomSeed(seedState))
}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
index fc6940cec50b..2991743da008 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/IOSuite.scala
@@ -6,7 +6,7 @@ import scala.sys.process._
class IOSuite extends FunSuite with BeforeAndAfterAll {
test("test MNISTIter") {
- //get data
+ // get data
"./scripts/get_mnist_data.sh" !
val params = Map(
@@ -21,7 +21,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
)
val mnistIter = IO.createIterator("MNISTIter", params)
- //test_loop
+ // test_loop
mnistIter.reset()
val nBatch = 600
var batchCount = 0
@@ -29,9 +29,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
val batch = mnistIter.next()
batchCount+=1
}
- //test loop
+ // test loop
assert(nBatch === batchCount)
- //test reset
+ // test reset
mnistIter.reset()
mnistIter.iterNext()
val label0 = mnistIter.getLabel().toArray
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
index b237500281c8..145ba4d0b71b 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
@@ -28,7 +28,9 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
val kv = KVStore.create()
val updater = new MXKVStoreUpdater {
override def update(key: Int, input: NDArray, stored: NDArray, handle: AnyRef): Unit = {
+ // scalastyle:off println
println(s"update on key $key")
+ // scalastyle:on println
stored += input * 2
}
}
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 25c6b9cdcf1b..524c9a52eb0f 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -142,8 +142,8 @@
false
${basedir}/src/main/scala
${basedir}/src/test/scala
- ${basedir}/src/main/resources/scalastyle_config.xml
- ${project.basedir}/scalastyle_output.xml
+ ${basedir}/scalastyle-config.xml
+ ${basedir}/target/scalastyle-output.xml
UTF-8
diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh
index 43ed4cc922a0..cd91c0db89bc 100755
--- a/tests/travis/run_test.sh
+++ b/tests/travis/run_test.sh
@@ -111,7 +111,6 @@ if [ ${TASK} == "python_test" ]; then
exit 0
fi
-
if [ ${TASK} == "scala_test" ]; then
if [ ${TRAVIS_OS_NAME} == "osx" ]; then
LIB_GOMP_PATH=`find /usr/local/lib -name libgomp.dylib | grep -v i386 | head -n1`
@@ -141,3 +140,10 @@ if [ ${TASK} == "scala_test" ]; then
exit 0
fi
+
+if [ ${TASK} == "scala_lint" ]; then
+ export JAVA_HOME=$(/usr/libexec/java_home)
+ cd scala-package/core
+ mvn scalastyle:check || exit -1
+ exit 0
+fi