diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index e65f9e761598..c0b2bfa696ea 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -85,5 +85,9 @@
org.scalatest
scalatest_${scala.binary.version}
+
+ org.scalacheck
+ scalacheck_${scala.binary.version}
+
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
index 206b8d13f108..76b5d1a3df17 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala
@@ -11,7 +11,7 @@ object Base {
type MXFloat = Float
type CPtrAddress = Long
// TODO: make it more friendly to java
- type Shape = Array[Int]
+ type Shape = Vector[Int]
type NDArrayHandle = CPtrAddress
type FunctionHandle = CPtrAddress
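The switch from `Array[Int]` to `Vector[Int]` for `Shape` buys structural equality and hashing, which the new `NDArray.equals` further down relies on. A minimal illustration (plain Scala, not part of the patch):

```scala
// Array equality is reference identity, so two identical shapes
// would not compare equal; Vector compares element-wise.
val a: Array[Int] = Array(2, 3)
val b: Array[Int] = Array(2, 3)
println(a == b)                                          // false
println(Vector(2, 3) == Vector(2, 3))                    // true
println(Vector(2, 3).hashCode == Vector(2, 3).hashCode)  // true
```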
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala
index cdacfa176719..08e390bf4f2c 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala
@@ -3,7 +3,9 @@ package ml.dmlc.mxnet
object Context {
val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned")
val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3)
- val defaultCtx = new Context("cpu", 0)
+ private var _defaultCtx = new Context("cpu", 0)
+
+ def defaultCtx: Context = _defaultCtx
def cpu(deviceId: Int = 0): Context = {
new Context("cpu", deviceId)
@@ -13,6 +15,16 @@ object Context {
new Context("gpu", deviceId)
}
+ def withScope[T](device: Context)(body: => T): T = {
+ val oldDefaultCtx = Context.defaultCtx
+ Context._defaultCtx = device
+ try {
+ body
+ } finally {
+ Context._defaultCtx = oldDefaultCtx
+ }
+ }
+
implicit def ctx2Array(ctx: Context): Array[Context] = Array(ctx)
}
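A sketch of how `withScope` is meant to be used (hypothetical usage, not part of the patch): any NDArray created inside the block without an explicit context picks up the scoped device, and the previous default is restored afterwards, even if the body throws.

```scala
// Assumes a second CPU device id purely for illustration.
val arr = Context.withScope(Context.cpu(1)) {
  NDArray.zeros(Vector(2, 3)) // allocated via Context.defaultCtx == cpu(1)
}
// Outside the block the default context is cpu(0) again.
```

Note that `_defaultCtx` is a plain `var`, so nested scopes on different threads would race; a `DynamicVariable` or thread-local would be needed for concurrent use.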
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
index 1715aae5549f..72d5ad3623ea 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala
@@ -294,7 +294,7 @@ class DataParallelExecutorManager(symbol: Symbol,
ctx.zipWithIndex.map { case (context, i) =>
val dataShapes =
trainData.provideData.map { case (name: String, shape: Shape) =>
- (name, Array(slices(i)._2 - slices(i)._1) ++ shape.drop(1))
+ (name, Vector(slices(i)._2 - slices(i)._1) ++ shape.drop(1))
}
symbol.simpleBind(context, "write", shapeDict = dataShapes)
}
@@ -334,7 +334,7 @@ class DataParallelExecutorManager(symbol: Symbol,
}.toArray
private val batchSize = trainData.batchSize
private val outputShapes: Array[Shape] = trainExecs(0).outputs.map { x: NDArray =>
- Array(batchSize) ++ x.shape.drop(1)
+ Vector(batchSize) ++ x.shape.drop(1)
}
private[mxnet] val cpuOutputArrays = outputShapes.map(NDArray.zeros(_))
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
index 96d80f1d4ca7..c49f8ce9d1a8 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala
@@ -164,9 +164,9 @@ class LibInfo {
keys: Array[String],
argIndPtr: Array[MXUint],
argShapeData: Array[MXUint],
- inShapeData: ListBuffer[Shape],
- outShapeData: ListBuffer[Shape],
- auxShapeData: ListBuffer[Shape],
+ inShapeData: ListBuffer[Array[Int]],
+ outShapeData: ListBuffer[Array[Int]],
+ auxShapeData: ListBuffer[Array[Int]],
complete: RefInt): Int
@native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: SymbolHandleRef): Int
@native def mxSymbolSaveToJSON(handle: SymbolHandle, out: RefString): Int
@@ -185,4 +185,7 @@ class LibInfo {
auxArgsHandle: Array[NDArrayHandle],
out: ExecutorHandleRef): Int
// scalastyle:on parameterNum
+
+ // Random
+ @native def mxRandomSeed(seed: Int): Int
}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
index 1d4642f3cbf7..e99dc30c12f8 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala
@@ -426,6 +426,7 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp
}
batch = data.next()
}
+ // TODO: we could use Symbol.Concat to do the same thing. Would it be more efficient?
outputs.map(NDArray.concatenate(_))
}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
index 20682b93cb4c..8ad6d32248d3 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
@@ -125,12 +125,12 @@ object NDArray {
*
* @return a new empty ndarray handle
*/
- private def newAllocHandle(shape: Array[Int],
+ private def newAllocHandle(shape: Shape,
ctx: Context,
delayAlloc: Boolean): NDArrayHandle = {
val hdl = new NDArrayHandleRef
checkCall(_LIB.mxNDArrayCreate(
- shape,
+ shape.toArray,
shape.length,
ctx.deviceTypeid,
ctx.deviceId,
@@ -217,14 +217,14 @@ object NDArray {
*
* @return The created NDArray.
*/
- def empty(shape: Array[Int], ctx: Context = null): NDArray = {
+ def empty(shape: Shape, ctx: Context = null): NDArray = {
val context = if (ctx == null) Context.defaultCtx else ctx
new NDArray(handle = NDArray.newAllocHandle(shape, context, delayAlloc = false))
}
- def empty(shape: Int *): NDArray = empty(shape.toArray)
+ def empty(shape: Int *): NDArray = empty(shape.toVector)
- def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toArray, ctx)
+ def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toVector, ctx)
/**
* Create a new NDArray filled with 0, with specified shape.
@@ -234,15 +234,15 @@ object NDArray {
*
* @return The created NDArray.
*/
- def zeros(shape: Array[Int], ctx: Context = null): NDArray = {
+ def zeros(shape: Shape, ctx: Context = null): NDArray = {
val arr = empty(shape, ctx)
arr.set(0f)
arr
}
- def zeros(shape: Int *): NDArray = zeros(shape.toArray)
+ def zeros(shape: Int *): NDArray = zeros(shape.toVector)
- def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toArray, ctx)
+ def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toVector, ctx)
/**
* Create a new NDArray filled with 1, with specified shape.
@@ -250,15 +250,15 @@ object NDArray {
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
- def ones(shape: Array[Int], ctx: Context = null): NDArray = {
+ def ones(shape: Shape, ctx: Context = null): NDArray = {
val arr = empty(shape, ctx)
arr.set(1f)
arr
}
- def ones(shape: Int *): NDArray = ones(shape.toArray)
+ def ones(shape: Int *): NDArray = ones(shape.toVector)
- def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toArray, ctx)
+ def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toVector, ctx)
/**
* Clip ndarray elements to range (from, to)
@@ -460,32 +460,29 @@ object NDArray {
* @param ctx The context of the NDArray, default to current default context.
* @return The created NDArray.
*/
- def array(sourceArr: Array[Float], shape: Array[Int], ctx: Context = null): NDArray = {
+ def array(sourceArr: Array[Float], shape: Shape, ctx: Context = null): NDArray = {
val arr = empty(shape, ctx)
arr.set(sourceArr)
arr
}
/**
- * Join a sequence of arrays at the first axis
+ * Join a sequence of arrays along axis 0
* TODO: shall we make it native?
* @param arrays
*/
def concatenate(arrays: Seq[NDArray], ctx: Context = null): NDArray = {
require(arrays != null && arrays.size > 0, "arrays empty")
val array0 = arrays.head
- val shape = Array.ofDim[Int](array0.shape.length)
- array0.shape.copyToArray(shape)
- var axis0 = shape(0)
- val shapeRemain = shape.drop(1)
+ val shape = array0.shape.drop(1)
+ var axis0 = array0.shape(0)
arrays.drop(1).foreach { array =>
- require(shapeRemain.sameElements(array.shape.drop(1)),
+ require(shape.sameElements(array.shape.drop(1)),
s"shape mismatch between (${array.shape.mkString(",")}) and (${shape.mkString(",")})")
axis0 += array.shape(0)
}
- shape(0) = axis0
- val output = NDArray.empty(shape, ctx)
+ val output = NDArray.empty(Vector(axis0) ++ shape, ctx)
axis0 = 0
arrays.foreach { array =>
output.slice(axis0, axis0 + array.shape(0)).set(array)
@@ -604,8 +601,14 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean =
slice(range._1, range._2)
}
- def slice(start: Int): NDArray = {
- slice(start, shape(0))
+ /**
+ * Return a sliced NDArray at the i-th position of axis 0.
+ * NDArray only supports contiguous slicing on axis 0.
+ * @param i slice index on axis 0
+ * @return a sliced NDArray that shares memory with the current one.
+ */
+ def slice(i: Int): NDArray = {
+ slice(i, i + 1)
}
/**
@@ -804,16 +807,28 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean =
* Get shape of current NDArray.
* @return an array representing shape of current ndarray
*/
- def shape: Array[Int] = {
+ def shape: Shape = {
val ndim = new MXUintRef
val data = ArrayBuffer[Int]()
checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data))
require(ndim.value == data.length, s"ndim=$ndim, while len(pdata)=${data.length}")
- data.toArray
+ data.toVector
}
// Get size of current NDArray.
def size: Int = shape.product
+
+ override def equals(o: Any): Boolean = o match {
+ case that: NDArray => {
+ that.shape == this.shape && that.toArray.sameElements(this.toArray)
+ }
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ // TODO: naive implementation. Array.hashCode is identity-based and would
+ // break the equals/hashCode contract, so hash the elements via toSeq.
+ shape.hashCode + toArray.toSeq.hashCode
+ }
}
// scalastyle:on finalize
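Taken together, the `slice` and `equals` changes give the following behavior (a sketch based on the code above, not part of the patch):

```scala
val nd = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Vector(2, 2))

// slice(i) now returns the single row i, not the tail starting at i.
val row = nd.slice(1)   // shape Vector(1, 2)
row.set(9f)             // shares memory, writes through to nd

// equals is structural: same shape and same contents.
val expected = NDArray.array(Array(1f, 2f, 9f, 9f), shape = Vector(2, 2))
assert(nd == expected)
```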
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
index d1554d29c5ac..e1279e095dfa 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala
@@ -1,6 +1,6 @@
package ml.dmlc.mxnet
-import ml.dmlc.mxnet.Base.Shape
+import ml.dmlc.mxnet.Base._
import ml.dmlc.mxnet.NDArray.{randomGaussian, randomUniform, empty}
/**
@@ -73,7 +73,6 @@ object Random {
* generated from GPU0 can be different from CPU.
*/
def seed(seedState: Int): Unit = {
- // TODO
-// checkCall(_LIB.mxRandomSeed(seedState))
+ checkCall(_LIB.mxRandomSeed(seedState))
}
}
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
index e58e6abaae39..43c382823955 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala
@@ -202,15 +202,15 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
def inferShape(keys: Array[String], indPtr: Array[Int], values: Array[Int])
: (Seq[Shape], Seq[Shape], Seq[Shape]) = {
- val argShapeData = ListBuffer.empty[Shape]
- val outShapeData = ListBuffer.empty[Shape]
- val auxShapeData = ListBuffer.empty[Shape]
+ val argShapeData = ListBuffer.empty[Array[Int]]
+ val outShapeData = ListBuffer.empty[Array[Int]]
+ val auxShapeData = ListBuffer.empty[Array[Int]]
val complete = new RefInt
checkCall(_LIB.mxSymbolInferShape(handle, indPtr.size - 1, keys, indPtr, values,
argShapeData, outShapeData, auxShapeData, complete))
if (complete.value != 0) {
- (argShapeData, outShapeData, auxShapeData)
+ (argShapeData.map(_.toVector), outShapeData.map(_.toVector), auxShapeData.map(_.toVector))
} else {
(null, null, null)
}
@@ -555,6 +555,10 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
bind(ctx, args, argsGrad, "write", Nil, null)
}
+ def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray]): Executor = {
+ bind(ctx, args, argsGrad, "write", Nil, null)
+ }
+
def bind(ctx: Context, args: Seq[NDArray]): Executor = {
val symbolArguments = listArguments()
bindHelper(ctx, symbolArguments, args, null,
@@ -670,6 +674,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) {
}
object Symbol {
+ private type SymbolCreateFunc = Map[String, Any] => Symbol
private val logger = LoggerFactory.getLogger(classOf[Symbol])
private val functions: Map[String, SymbolFunction] = initSymbolModule()
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)
@@ -688,23 +693,23 @@ object Symbol {
sym
}
- def FullyConnected: Map[String, Any] => Symbol = {
+ def FullyConnected: SymbolCreateFunc = {
FullyConnected(null)
}
- def FullyConnected(attr: Map[String, String]): Map[String, Any] => Symbol = {
+ def FullyConnected(attr: Map[String, String]): SymbolCreateFunc = {
createNoCheck("FullyConnected", attr)
}
- def Activation: Map[String, Any] => Symbol = {
+ def Activation: SymbolCreateFunc = {
Activation(null)
}
- def Activation(attr: Map[String, String]): Map[String, Any] => Symbol = {
+ def Activation(attr: Map[String, String]): SymbolCreateFunc = {
createNoCheck("Activation", attr)
}
- def Convolution(attr: Map[String, String]): Map[String, Any] => Symbol = {
+ def Convolution(attr: Map[String, String]): SymbolCreateFunc = {
createNoCheck("Convolution", attr)
}
@@ -732,6 +737,43 @@ object Symbol {
createNoCheck("Cast")
}
+ def ElementWiseSum(name: String, inputs: Symbol *): Symbol = {
+ create("ElementWiseSum", inputs.toArray, Map("name" -> name), null)
+ }
+
+ def ElementWiseSum(inputs: Seq[Symbol], name: String): Symbol = {
+ create("ElementWiseSum", inputs.toArray, Map("name" -> name), null)
+ }
+
+ def Concat(inputs: Seq[Symbol],
+ paramKwargs: Map[String, Any],
+ attr: Map[String, String] = null): Symbol = {
+ create("Concat", inputs.toArray,
+ paramKwargs.map { case (k, v) => (k, v.toString) }, attr)
+ }
+
+ // Use logistic regression on the final output of a net.
+ // Logistic regression is suitable for binary classification or probability prediction tasks.
+ def LogisticRegressionOutput(inputs: Seq[Symbol], attr: Map[String, String] = null): Symbol = {
+ create("LogisticRegressionOutput", inputs.toArray, null, attr)
+ }
+
+ // Use linear regression on the final output of a net.
+ def LinearRegressionOutput(inputs: Seq[Symbol], attr: Map[String, String] = null): Symbol = {
+ create("LinearRegressionOutput", inputs.toArray, null, attr)
+ }
+
+ /**
+ * Apply swapaxis to input.
+ * @param data Input data to the SwapAxisOp.
+ * @param dim1 (non-negative), default=0, the first axis to be swapped.
+ * @param dim2 (non-negative), default=0, the second axis to be swapped.
+ */
+ def SwapAxis(data: Symbol, dim1: Int = 0, dim2: Int = 0,
+ attr: Map[String, String] = null): Symbol = {
+ createNoCheck("SwapAxis")(Map("data" -> data, "dim1" -> dim1, "dim2" -> dim2))
+ }
+
/**
* Create a symbol that groups symbols together.
* @param symbols List of symbols to be grouped.
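A short end-to-end sketch of the new `ElementWiseSum` varargs overload, bound and evaluated via the existing `bind(ctx, args)` (illustrative only; values are assumptions):

```scala
val a = Symbol.Variable("a")
val b = Symbol.Variable("b")
val esum = Symbol.ElementWiseSum("esum", a, b)

val exec = esum.bind(Context.cpu(),
  args = Seq(NDArray.ones(Vector(2)), NDArray.ones(Vector(2)) * 2f))
exec.forward()
// exec.outputs(0).toArray == Array(3f, 3f)
```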
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala
new file mode 100644
index 000000000000..2fc2f02d09a2
--- /dev/null
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala
@@ -0,0 +1,16 @@
+package ml.dmlc.mxnet
+
+object CheckUtils {
+ def reldiff(a: NDArray, b: NDArray): Float = {
+ val diff = NDArray.sum(NDArray.abs(a - b)).toScalar
+ val norm = NDArray.sum(NDArray.abs(a)).toScalar
+ diff / norm
+ }
+
+ def reldiff(a: Array[Float], b: Array[Float]): Float = {
+ val diff =
+ (a zip b).map { case (aElem, bElem) => Math.abs(aElem - bElem) }.sum
+ val norm: Float = a.map(x => Math.abs(x)).sum
+ diff / norm
+ }
+}
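`reldiff` is a scale-free error measure, `sum(|a - b|) / sum(|a|)`, which the test suites below use instead of exact float comparison. For instance (hypothetical values):

```scala
val x = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1))
val y = NDArray.array(Array(1f, 2f, 3.001f), shape = Vector(3, 1))
assert(CheckUtils.reldiff(x, y) < 1e-3f)  // 0.001 / 6 ≈ 1.7e-4
```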
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala
index 8716bb181e04..4da8e6728e2f 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala
@@ -1,10 +1,11 @@
package ml.dmlc.mxnet
import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import ml.dmlc.mxnet.CheckUtils._
class ExecutorSuite extends FunSuite with BeforeAndAfterAll {
test("bind") {
- val shape = Array(100, 30)
+ val shape = Vector(100, 30)
val lhs = Symbol.Variable("lhs")
val rhs = Symbol.Variable("rhs")
val ret = lhs + rhs
@@ -39,10 +40,4 @@ class ExecutorSuite extends FunSuite with BeforeAndAfterAll {
assert(reldiff(lhsGrad, lhsGrad2) < 1e-6)
assert(reldiff(rhsGrad, rhsGrad2) < 1e-6)
}
-
- private def reldiff(a: NDArray, b: NDArray): Float = {
- val diff = NDArray.sum(NDArray.abs(a - b)).toScalar
- val norm = NDArray.sum(NDArray.abs(a)).toScalar
- diff / norm
- }
}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
index 2812623b2d02..3e9fc625b783 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala
@@ -5,7 +5,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}
class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
test("init and pull") {
val kv = KVStore.create()
- val shape = Array(2, 1)
+ val shape = Vector(2, 1)
val ndArray = NDArray.zeros(shape)
kv.init(3, NDArray.ones(shape))
@@ -15,7 +15,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
test("push and pull") {
val kv = KVStore.create()
- val shape = Array(2, 1)
+ val shape = Vector(2, 1)
val ndArray = NDArray.zeros(shape)
kv.init(3, NDArray.ones(shape))
@@ -36,7 +36,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll {
}
kv.setUpdater(updater)
- val shape = Array(2, 1)
+ val shape = Vector(2, 1)
val ndArray = NDArray.zeros(shape)
kv.init(3, NDArray.ones(shape) * 4)
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala
index d06ab134ece3..eacc8d6b3206 100644
--- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala
@@ -4,10 +4,9 @@ import java.io.File
import java.util.concurrent.atomic.AtomicInteger
import ml.dmlc.mxnet.NDArrayConversions._
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import org.scalactic.Tolerance._
+import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite}
-class NDArraySuite extends FunSuite with BeforeAndAfterAll {
+class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
private val sequence: AtomicInteger = new AtomicInteger(0)
test("to java array") {
@@ -104,7 +103,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
}
test("rsqrt") {
- val ndarray = NDArray.array(Array(1f, 4f), shape = Array(2, 1))
+ val ndarray = NDArray.array(Array(1f, 4f), shape = Vector(2, 1))
assert(NDArray.rsqrt(ndarray).toArray === Array(1f, 0.5f))
}
@@ -117,7 +116,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
}
test("one hot encode") {
- val indices = NDArray.array(Array(1f, 0f, 2f), shape = Array(3))
+ val indices = NDArray.array(Array(1f, 0f, 2f), shape = Vector(3))
val array = NDArray.empty(3, 3)
NDArray.onehotEncode(indices, array)
assert(array.shape === Array(3, 3))
@@ -127,22 +126,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
}
test("dot") {
- val arr1 = NDArray.array(Array(1f, 2f), shape = Array(1, 2))
- val arr2 = NDArray.array(Array(3f, 4f), shape = Array(2, 1))
+ val arr1 = NDArray.array(Array(1f, 2f), shape = Vector(1, 2))
+ val arr2 = NDArray.array(Array(3f, 4f), shape = Vector(2, 1))
val res = NDArray.dot(arr1, arr2)
assert(res.shape === Array(1, 1))
assert(res.toArray === Array(11f))
}
test("choose_element_0index") {
- val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 6f, 5f), shape = Array(2, 3))
- val indices = NDArray.array(Array(0f, 1f), shape = Array(2))
+ val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 6f, 5f), shape = Vector(2, 3))
+ val indices = NDArray.array(Array(0f, 1f), shape = Vector(2))
val res = NDArray.chooseElement0Index(arr, indices)
assert(res.toArray === Array(1f, 6f))
}
test("copy to") {
- val source = NDArray.array(Array(1f, 2f, 3f), shape = Array(1, 3))
+ val source = NDArray.array(Array(1f, 2f, 3f), shape = Vector(1, 3))
val dest = NDArray.empty(1, 3)
source.copyTo(dest)
assert(dest.shape === Array(1, 3))
@@ -173,32 +172,32 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
}
test("abs") {
- val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Array(3, 1))
+ val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Vector(3, 1))
assert(NDArray.abs(arr).toArray === Array(1f, 2f, 3f))
}
test("sign") {
- val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Array(3, 1))
+ val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Vector(3, 1))
assert(NDArray.sign(arr).toArray === Array(-1f, -1f, 1f))
}
test("round") {
- val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1))
+ val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1))
assert(NDArray.round(arr).toArray === Array(2f, 2f, 4f))
}
test("ceil") {
- val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1))
+ val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1))
assert(NDArray.ceil(arr).toArray === Array(2f, 3f, 4f))
}
test("floor") {
- val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1))
+ val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1))
assert(NDArray.floor(arr).toArray === Array(1f, 2f, 3f))
}
test("square") {
- val arr = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1))
+ val arr = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1))
assert(NDArray.square(arr).toArray === Array(1f, 4f, 9f))
}
@@ -226,30 +225,30 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
}
test("max") {
- val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1))
+ val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1))
assert(NDArray.max(arr).toScalar === 3.7f +- 1e-3f)
}
test("min") {
- val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1))
+ val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1))
assert(NDArray.min(arr).toScalar === 1.5f +- 1e-3f)
}
test("sum") {
- val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Array(2, 2))
+ val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Vector(2, 2))
assert(NDArray.sum(arr).toScalar === 10f +- 1e-3f)
}
test("argmaxChannel") {
- val arr = NDArray.array(Array(1f, 2f, 4f, 3f), shape = Array(2, 2))
+ val arr = NDArray.array(Array(1f, 2f, 4f, 3f), shape = Vector(2, 2))
val argmax = NDArray.argmaxChannel(arr)
assert(argmax.shape === Array(2))
assert(argmax.toArray === Array(1f, 0f))
}
test("concatenate") {
- val arr1 = NDArray.array(Array(1f, 2f, 4f, 3f, 3f, 3f), shape = Array(2, 3))
- val arr2 = NDArray.array(Array(8f, 7f, 6f), shape = Array(1, 3))
+ val arr1 = NDArray.array(Array(1f, 2f, 4f, 3f, 3f, 3f), shape = Vector(2, 3))
+ val arr2 = NDArray.array(Array(8f, 7f, 6f), shape = Vector(1, 3))
val arr = NDArray.concatenate(arr1, arr2)
assert(arr.shape === Array(3, 3))
assert(arr.toArray === Array(1f, 2f, 4f, 3f, 3f, 3f, 8f, 7f, 6f))
@@ -259,7 +258,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
val filename
= s"${System.getProperty("java.io.tmpdir")}/ndarray-${sequence.getAndIncrement}.bin"
try {
- val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1))
+ val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1))
NDArray.save(filename, Map("local" -> ndarray))
val (keys, arrays) = NDArray.load(filename)
assert(keys.length === 1)
@@ -278,7 +277,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
val filename
= s"${System.getProperty("java.io.tmpdir")}/ndarray-${sequence.getAndIncrement}.bin"
try {
- val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1))
+ val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1))
NDArray.save(filename, Array(ndarray))
val (keys, arrays) = NDArray.load(filename)
assert(keys.length === 0)
@@ -298,4 +297,26 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll {
assert(ctx.deviceType === "cpu")
assert(ctx.deviceId === 0)
}
+
+ test("equals") {
+ val ndarray1 = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1))
+ val ndarray2 = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1))
+ val ndarray3 = NDArray.array(Array(1f, 2f, 3f), shape = Vector(1, 3))
+ val ndarray4 = NDArray.array(Array(3f, 2f, 3f), shape = Vector(3, 1))
+ ndarray1 shouldEqual ndarray2
+ ndarray1 shouldNot equal(ndarray3)
+ ndarray1 shouldNot equal(ndarray4)
+ }
+
+ test("slice") {
+ val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Vector(3, 2))
+
+ val arr1 = arr.slice(1)
+ assert(arr1.shape === Vector(1, 2))
+ assert(arr1.toArray === Array(3f, 4f))
+
+ val arr2 = arr.slice(1, 3)
+ assert(arr2.shape === Vector(2, 2))
+ assert(arr2.toArray === Array(3f, 4f, 5f, 6f))
+ }
}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala
new file mode 100644
index 000000000000..8e99553af3b8
--- /dev/null
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala
@@ -0,0 +1,174 @@
+package ml.dmlc.mxnet
+
+import ml.dmlc.mxnet.Base.Shape
+import ml.dmlc.mxnet.CheckUtils._
+import org.scalatest.prop.GeneratorDrivenPropertyChecks
+import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite}
+import org.scalacheck.Gen
+import scala.collection.mutable
+
+class OperatorSuite extends FunSuite with BeforeAndAfterAll
+ with Matchers with GeneratorDrivenPropertyChecks {
+ private def checkElementwiseSumWithShape(shape: Shape, n: Int) = {
+ // forward
+ val inputs = (0 until n).map(i => Symbol.Variable(s"arg $i"))
+ val out = Symbol.ElementWiseSum("esum", inputs: _*)
+ val arr = (0 until n).map(_ => Random.uniform(-10, 10, shape))
+ val arrGrad = (0 until n).map(_ => NDArray.empty(shape))
+ val exec = out.bind(Context.cpu(), args = arr, argsGrad = arrGrad)
+ exec.forward()
+ val forwardOutput = exec.outputs(0)
+ val forwardOutputExpected = arr.reduce(_ + _)
+ assert(reldiff(forwardOutput, forwardOutputExpected) < 1e-6)
+
+ // backward
+ val outGrad = Random.uniform(-10, 10, shape)
+ exec.backward(outGrad)
+ arrGrad.foreach(grad => assert(grad === outGrad))
+ }
+
+ test("elementwise sum") {
+ checkElementwiseSumWithShape(Vector(5, 5, 3), 4)
+ forAll (Gen.choose(1, 4), Gen.choose(1, 8)) { (dim, n) =>
+ forAll (Gen.listOfN(dim, Gen.choose(1, Math.pow(1000, 1.0 / dim).toInt))) { shape =>
+ checkElementwiseSumWithShape(shape.toVector, n)
+ }
+ }
+ }
+
+ // TODO: checkSliceChannel
+
+ private def checkConcatWithShape(shapes: Seq[Shape], dimension: Int, skipSecond: Boolean) = {
+ // if skipSecond is true, second argument will not have gradient.
+ // it is to test #1130
+ // forward
+ val targetDim = shapes.map(_(dimension)).sum
+
+ val inputs = (0 until shapes.size).map(i => Symbol.Variable(s"arg$i"))
+ val out = Symbol.Concat(inputs, Map("name" -> "conc", "dim" -> dimension))
+ val arr = shapes.map { shape =>
+ val nd = NDArray.empty(shape)
+ nd.set(shape(dimension))
+ }
+ val arrNp = arr.map(_.copy())
+ val arrGrad = shapes.map(NDArray.empty(_))
+ val argNames = out.listArguments()
+ val dictGrad =
+ (argNames zip arrGrad).filter { case (name, _) =>
+ !skipSecond || name != "arg1"
+ }.toMap
+
+ val (argShapes, outShapes, auxShapes) = out.inferShape(argNames.zip(shapes).toMap)
+ val outGrad = NDArray.empty(outShapes(0))
+ val exec1 = out.bind(Context.cpu(), arr, dictGrad)
+ exec1.forward()
+ val out1 = exec1.outputs(0)
+ // FIXME: only supports concatenation at axis 0
+ val ret = NDArray.concatenate(arr)
+ assert(out1 === ret)
+
+ // backward
+ out1.copyTo(outGrad)
+ outGrad += 1
+ exec1.backward(outGrad)
+ argNames.zipWithIndex.foreach { case (name, i) =>
+ if (!skipSecond || name != "arg1") {
+ val grad = dictGrad(name)
+ val npGrad = arrNp(i)
+ assert(grad === npGrad + 1)
+ }
+ }
+ }
+
+ test("concat") {
+ val merge = Array(2, 3, 4, 5, 6)
+ forAll (Gen.choose(2, 5)) { dim =>
+ val shapes = mutable.ArrayBuffer.empty[Vector[Int]]
+ for (i <- 0 until dim) {
+ shapes += Vector(merge(i), 2)
+ }
+ // TODO: check dimension > 0
+ checkConcatWithShape(shapes, 0, skipSecond = true)
+ checkConcatWithShape(shapes, 0, skipSecond = false)
+ }
+ }
+
+ private def checkRegression(model: Symbol,
+ forward: Float => Float,
+ backward: (Float, Float) => Float) = {
+ val shape = Vector(3, 1)
+ val arrData = Random.uniform(-1, 1, shape)
+ val arrLabel = Random.uniform(0, 1, Vector(shape.head))
+ val arrGrad = NDArray.empty(shape)
+ val exec1 = model.bind(Context.cpu(),
+ args = Array(arrData, arrLabel), argsGrad = Map("data" -> arrGrad))
+ exec1.forward()
+ assert(exec1.outputs(0).shape === shape)
+ val out1 = exec1.outputs(0).toArray
+ val npout = arrData.toArray.map(forward(_))
+ assert(CheckUtils.reldiff(npout, out1) < 1e-6f)
+
+ exec1.backward()
+ // arrData shape: Vector(3, 1)
+ // arrLabel shape: Vector(3)
+ val npoutBack = (npout zip arrLabel.toArray).map { case (data, label) =>
+ backward(data, label)
+ }
+ assert(CheckUtils.reldiff(npoutBack, arrGrad.toArray) < 1e-6f)
+ }
+
+ test("regression") {
+ checkRegression(Symbol.LogisticRegressionOutput(
+ Array(Symbol.Variable("data"), Symbol.Variable("label"))),
+ (x: Float) => 1.0f / (1.0f + Math.exp(-x).toFloat),
+ (x: Float, y: Float) => x - y)
+ checkRegression(Symbol.LinearRegressionOutput(
+ Array(Symbol.Variable("data"), Symbol.Variable("label"))),
+ (x: Float) => x,
+ (x: Float, y: Float) => x - y)
+ }
+
+ // TODO: test softmax
+
+ test("swap axes") {
+ val data = Symbol.Variable("data")
+ val shape = Vector(2, 3, 4)
+ val arrData = NDArray.ones(shape)
+ arrData.slice(0).set(1f)
+ arrData.slice(1).set(2f)
+ // arrData =
+ //
+ // [[[ 1., 1., 1., 1.],
+ // [ 1., 1., 1., 1.],
+ // [ 1., 1., 1., 1.]],
+ //
+ // [[ 2., 2., 2., 2.],
+ // [ 2., 2., 2., 2.],
+ // [ 2., 2., 2., 2.]]]
+ val swap0 = Symbol.SwapAxis(data = data, dim1 = 0, dim2 = 2)
+ val swap = Symbol.SwapAxis(data = swap0, dim1 = 1, dim2 = 2)
+ val exec = swap.bind(Context.cpu(), args = Array(arrData))
+ exec.forward()
+ val out = exec.outputs(0)
+
+ // After swapaxes(swapaxes(arrData, 0, 2), 1, 2)
+ // out should be
+ // [[[ 1., 1., 1.],
+ // [ 2., 2., 2.]],
+ //
+ // [[ 1., 1., 1.],
+ // [ 2., 2., 2.]],
+ //
+ // [[ 1., 1., 1.],
+ // [ 2., 2., 2.]],
+ //
+ // [[ 1., 1., 1.],
+ // [ 2., 2., 2.]]]
+ assert(out.shape === Vector(4, 2, 3))
+ for (i <- 0 until 4) {
+ val axis0 = out.slice(i)
+ assert(CheckUtils.reldiff(axis0.toArray, Array(1f, 1f, 1f, 2f, 2f, 2f)) < 1e-6f)
+ }
+ }
+}
diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala
new file mode 100644
index 000000000000..b1913b5e10d7
--- /dev/null
+++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala
@@ -0,0 +1,36 @@
+package ml.dmlc.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class RandomSuite extends FunSuite with BeforeAndAfterAll {
+ test("uniform on cpu") {
+ Context.withScope(Context.cpu()) {
+ val (a, b) = (-10, 10)
+ val shape = Vector(100, 100)
+ Random.seed(128)
+ val un1 = Random.uniform(a, b, shape)
+ Random.seed(128)
+ val un2 = Random.uniform(a, b, shape)
+ assert(un1 === un2)
+ assert(Math.abs(un1.toArray.sum / un1.size - (a + b) / 2f) < 0.1)
+ }
+ }
+
+ test("normal on cpu") {
+ val (mu, sigma) = (10f, 2f)
+ val shape = Vector(100, 100)
+ Random.seed(128)
+ val ret1 = Random.normal(mu, sigma, shape)
+ Random.seed(128)
+ val ret2 = Random.normal(mu, sigma, shape)
+ assert(ret1 === ret2)
+
+ val array = ret1.toArray
+ val mean = array.sum / ret1.size
+ val devs = array.map(score => (score - mean) * (score - mean))
+ val stddev = Math.sqrt(devs.sum / ret1.size)
+
+ assert(Math.abs(mean - mu) < 0.1)
+ assert(Math.abs(stddev - sigma) < 0.1)
+ }
+}
diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
index 311ec87ab265..df5907b85566 100644
--- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc
@@ -1142,3 +1142,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBindX
setLongField(env, jexecOut, (long) out);
return ret;
}
+
+JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxRandomSeed
+ (JNIEnv *env, jobject obj, jint seed) {
+ return MXRandomSeed(seed);
+}
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 3e2a9a2494be..28b7b44a9ed9 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -233,6 +233,12 @@
2.2.4
test
+
+ org.scalacheck
+ scalacheck_${scala.binary.version}
+ 1.11.3
+ test
+