From 04ea7978f2c97487b4d7707b9872f74e13666b29 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 3 Feb 2016 00:08:31 +0800 Subject: [PATCH 1/4] Change NDArray shape's type from Array to Vector, add Random unit test cases --- .../src/main/scala/ml/dmlc/mxnet/Base.scala | 2 +- .../main/scala/ml/dmlc/mxnet/Context.scala | 14 ++++++- .../main/scala/ml/dmlc/mxnet/Executor.scala | 4 +- .../main/scala/ml/dmlc/mxnet/LibInfo.scala | 9 ++-- .../main/scala/ml/dmlc/mxnet/NDArray.scala | 39 ++++++++--------- .../src/main/scala/ml/dmlc/mxnet/Random.scala | 5 +-- .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 8 ++-- .../scala/ml/dmlc/mxnet/ExecutorSuite.scala | 2 +- .../scala/ml/dmlc/mxnet/KVStoreSuite.scala | 6 +-- .../scala/ml/dmlc/mxnet/NDArraySuite.scala | 42 +++++++++---------- .../scala/ml/dmlc/mxnet/RandomSuite.scala | 36 ++++++++++++++++ .../main/native/ml_dmlc_mxnet_native_c_api.cc | 5 +++ 12 files changed, 112 insertions(+), 60 deletions(-) create mode 100644 scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala index 206b8d13f108..76b5d1a3df17 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Base.scala @@ -11,7 +11,7 @@ object Base { type MXFloat = Float type CPtrAddress = Long // TODO: make it more friendly to java - type Shape = Array[Int] + type Shape = Vector[Int] type NDArrayHandle = CPtrAddress type FunctionHandle = CPtrAddress diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala index cdacfa176719..3f9b205fa0f7 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala @@ -3,7 +3,9 @@ package ml.dmlc.mxnet object Context { val devtype2str = Map(1 -> "cpu", 2 -> "gpu", 3 -> "cpu_pinned") val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3) - val defaultCtx = new Context("cpu", 0) + private var _defaultCtx = new Context("cpu", 0) + + def defaultCtx = _defaultCtx def cpu(deviceId: Int = 0): Context = { new Context("cpu", deviceId) @@ -13,6 +15,16 @@ object Context { new Context("gpu", deviceId) } + def withScope[T](device: Context)(body: => T): T = { + val oldDefaultCtx = Context.defaultCtx + Context._defaultCtx = device + try { + body + } finally { + Context._defaultCtx = oldDefaultCtx + } + } + implicit def ctx2Array(ctx: Context): Array[Context] = Array(ctx) } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala index 1715aae5549f..72d5ad3623ea 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Executor.scala @@ -294,7 +294,7 @@ class DataParallelExecutorManager(symbol: Symbol, ctx.zipWithIndex.map { case (context, i) => val dataShapes = trainData.provideData.map { case (name: String, shape: Shape) => - (name, Array(slices(i)._2 - slices(i)._1) ++ shape.drop(1)) + (name, Vector(slices(i)._2 - slices(i)._1) ++ shape.drop(1)) } symbol.simpleBind(context, "write", shapeDict = dataShapes) } @@ -334,7 +334,7 @@ class DataParallelExecutorManager(symbol: Symbol, }.toArray private val batchSize = trainData.batchSize private val outputShapes: Array[Shape] = trainExecs(0).outputs.map { x: NDArray => - Array(batchSize) ++ 
x.shape.drop(1) + Vector(batchSize) ++ x.shape.drop(1) } private[mxnet] val cpuOutputArrays = outputShapes.map(NDArray.zeros(_)) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala index 96d80f1d4ca7..c49f8ce9d1a8 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/LibInfo.scala @@ -164,9 +164,9 @@ class LibInfo { keys: Array[String], argIndPtr: Array[MXUint], argShapeData: Array[MXUint], - inShapeData: ListBuffer[Shape], - outShapeData: ListBuffer[Shape], - auxShapeData: ListBuffer[Shape], + inShapeData: ListBuffer[Array[Int]], + outShapeData: ListBuffer[Array[Int]], + auxShapeData: ListBuffer[Array[Int]], complete: RefInt): Int @native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: SymbolHandleRef): Int @native def mxSymbolSaveToJSON(handle: SymbolHandle, out: RefString): Int @@ -185,4 +185,7 @@ class LibInfo { auxArgsHandle: Array[NDArrayHandle], out: ExecutorHandleRef): Int // scalastyle:on parameterNum + + // Random + @native def mxRandomSeed(seed: Int): Int } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 20682b93cb4c..33843979e0a6 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -125,12 +125,12 @@ object NDArray { * * @return a new empty ndarray handle */ - private def newAllocHandle(shape: Array[Int], + private def newAllocHandle(shape: Shape, ctx: Context, delayAlloc: Boolean): NDArrayHandle = { val hdl = new NDArrayHandleRef checkCall(_LIB.mxNDArrayCreate( - shape, + shape.toArray, shape.length, ctx.deviceTypeid, ctx.deviceId, @@ -217,14 +217,14 @@ object NDArray { * * @return The created NDArray. */ - def empty(shape: Array[Int], ctx: Context = null): NDArray = { + def empty(shape: Shape, ctx: Context = null): NDArray = { val context = if (ctx == null) Context.defaultCtx else ctx new NDArray(handle = NDArray.newAllocHandle(shape, context, delayAlloc = false)) } - def empty(shape: Int *): NDArray = empty(shape.toArray) + def empty(shape: Int *): NDArray = empty(shape.toVector) - def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toArray, ctx) + def empty(ctx: Context, shape: Int *): NDArray = empty(shape.toVector, ctx) /** * Create a new NDArray filled with 0, with specified shape. @@ -234,15 +234,15 @@ object NDArray { * * @return The created NDArray. */ - def zeros(shape: Array[Int], ctx: Context = null): NDArray = { + def zeros(shape: Shape, ctx: Context = null): NDArray = { val arr = empty(shape, ctx) arr.set(0f) arr } - def zeros(shape: Int *): NDArray = zeros(shape.toArray) + def zeros(shape: Int *): NDArray = zeros(shape.toVector) - def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toArray, ctx) + def zeros(ctx: Context, shape: Int *): NDArray = zeros(shape.toVector, ctx) /** * Create a new NDArray filled with 1, with specified shape. @@ -250,15 +250,15 @@ object NDArray { * @param ctx The context of the NDArray, default to current default context. * @return The created NDArray. 
*/ - def ones(shape: Array[Int], ctx: Context = null): NDArray = { + def ones(shape: Shape, ctx: Context = null): NDArray = { val arr = empty(shape, ctx) arr.set(1f) arr } - def ones(shape: Int *): NDArray = ones(shape.toArray) + def ones(shape: Int *): NDArray = ones(shape.toVector) - def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toArray, ctx) + def ones(ctx: Context, shape: Int *): NDArray = ones(shape.toVector, ctx) /** * Clip ndarray elements to range (from, to) @@ -460,7 +460,7 @@ object NDArray { * @param ctx The context of the NDArray, default to current default context. * @return The created NDArray. */ - def array(sourceArr: Array[Float], shape: Array[Int], ctx: Context = null): NDArray = { + def array(sourceArr: Array[Float], shape: Shape, ctx: Context = null): NDArray = { val arr = empty(shape, ctx) arr.set(sourceArr) arr @@ -474,18 +474,15 @@ object NDArray { def concatenate(arrays: Seq[NDArray], ctx: Context = null): NDArray = { require(arrays != null && arrays.size > 0, "arrays empty") val array0 = arrays.head - val shape = Array.ofDim[Int](array0.shape.length) - array0.shape.copyToArray(shape) - var axis0 = shape(0) - val shapeRemain = shape.drop(1) + val shape = array0.shape.drop(1) + var axis0 = array0.shape(0) arrays.drop(1).foreach { array => - require(shapeRemain.sameElements(array.shape.drop(1)), + require(shape.sameElements(array.shape.drop(1)), s"shape mismatch between (${array.shape.mkString(",")}) and (${shape.mkString(",")})") axis0 += array.shape(0) } - shape(0) = axis0 - val output = NDArray.empty(shape, ctx) + val output = NDArray.empty(Vector(axis0) ++ shape, ctx) axis0 = 0 arrays.foreach { array => output.slice(axis0, axis0 + array.shape(0)).set(array) @@ -804,12 +801,12 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = * Get shape of current NDArray. * @return an array representing shape of current ndarray */ - def shape: Array[Int] = { + def shape: Shape = { val ndim = new MXUintRef val data = ArrayBuffer[Int]() checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data)) require(ndim.value == data.length, s"ndim=$ndim, while len(pdata)=${data.length}") - data.toArray + data.toVector } // Get size of current NDArray. diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala index d1554d29c5ac..e1279e095dfa 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala @@ -1,6 +1,6 @@ package ml.dmlc.mxnet -import ml.dmlc.mxnet.Base.Shape +import ml.dmlc.mxnet.Base._ import ml.dmlc.mxnet.NDArray.{randomGaussian, randomUniform, empty} /** @@ -73,7 +73,6 @@ object Random { * generated from GPU0 can be different from CPU. 
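   * A fixed seed therefore only guarantees reproducible sequences on the same kind of device.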
*/ def seed(seedState: Int): Unit = { - // TODO -// checkCall(_LIB.mxRandomSeed(seedState)) + checkCall(_LIB.mxRandomSeed(seedState)) } } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index e58e6abaae39..50223765c7bf 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -202,15 +202,15 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { def inferShape(keys: Array[String], indPtr: Array[Int], values: Array[Int]) : (Seq[Shape], Seq[Shape], Seq[Shape]) = { - val argShapeData = ListBuffer.empty[Shape] - val outShapeData = ListBuffer.empty[Shape] - val auxShapeData = ListBuffer.empty[Shape] + val argShapeData = ListBuffer.empty[Array[Int]] + val outShapeData = ListBuffer.empty[Array[Int]] + val auxShapeData = ListBuffer.empty[Array[Int]] val complete = new RefInt checkCall(_LIB.mxSymbolInferShape(handle, indPtr.size - 1, keys, indPtr, values, argShapeData, outShapeData, auxShapeData, complete)) if (complete.value != 0) { - (argShapeData, outShapeData, auxShapeData) + (argShapeData.map(_.toVector), outShapeData.map(_.toVector), auxShapeData.map(_.toVector)) } else { (null, null, null) } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala index 8716bb181e04..7fbdb253d778 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala @@ -4,7 +4,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite} class ExecutorSuite extends FunSuite with BeforeAndAfterAll { test("bind") { - val shape = Array(100, 30) + val shape = Vector(100, 30) val lhs = Symbol.Variable("lhs") val rhs = Symbol.Variable("rhs") val ret = lhs + rhs diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala index 2812623b2d02..3e9fc625b783 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/KVStoreSuite.scala @@ -5,7 +5,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite} class KVStoreSuite extends FunSuite with BeforeAndAfterAll { test("init and pull") { val kv = KVStore.create() - val shape = Array(2, 1) + val shape = Vector(2, 1) val ndArray = NDArray.zeros(shape) kv.init(3, NDArray.ones(shape)) @@ -15,7 +15,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll { test("push and pull") { val kv = KVStore.create() - val shape = Array(2, 1) + val shape = Vector(2, 1) val ndArray = NDArray.zeros(shape) kv.init(3, NDArray.ones(shape)) @@ -36,7 +36,7 @@ class KVStoreSuite extends FunSuite with BeforeAndAfterAll { } kv.setUpdater(updater) - val shape = Array(2, 1) + val shape = Vector(2, 1) val ndArray = NDArray.zeros(shape) kv.init(3, NDArray.ones(shape) * 4) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala index d06ab134ece3..b743e2d616a3 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala @@ -104,7 +104,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { } test("rsqrt") { - val ndarray = NDArray.array(Array(1f, 4f), shape = Array(2, 1)) + val ndarray = 
NDArray.array(Array(1f, 4f), shape = Vector(2, 1)) assert(NDArray.rsqrt(ndarray).toArray === Array(1f, 0.5f)) } @@ -117,7 +117,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { } test("one hot encode") { - val indices = NDArray.array(Array(1f, 0f, 2f), shape = Array(3)) + val indices = NDArray.array(Array(1f, 0f, 2f), shape = Vector(3)) val array = NDArray.empty(3, 3) NDArray.onehotEncode(indices, array) assert(array.shape === Array(3, 3)) @@ -127,22 +127,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { } test("dot") { - val arr1 = NDArray.array(Array(1f, 2f), shape = Array(1, 2)) - val arr2 = NDArray.array(Array(3f, 4f), shape = Array(2, 1)) + val arr1 = NDArray.array(Array(1f, 2f), shape = Vector(1, 2)) + val arr2 = NDArray.array(Array(3f, 4f), shape = Vector(2, 1)) val res = NDArray.dot(arr1, arr2) assert(res.shape === Array(1, 1)) assert(res.toArray === Array(11f)) } test("choose_element_0index") { - val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 6f, 5f), shape = Array(2, 3)) - val indices = NDArray.array(Array(0f, 1f), shape = Array(2)) + val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 6f, 5f), shape = Vector(2, 3)) + val indices = NDArray.array(Array(0f, 1f), shape = Vector(2)) val res = NDArray.chooseElement0Index(arr, indices) assert(res.toArray === Array(1f, 6f)) } test("copy to") { - val source = NDArray.array(Array(1f, 2f, 3f), shape = Array(1, 3)) + val source = NDArray.array(Array(1f, 2f, 3f), shape = Vector(1, 3)) val dest = NDArray.empty(1, 3) source.copyTo(dest) assert(dest.shape === Array(1, 3)) @@ -173,32 +173,32 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { } test("abs") { - val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Array(3, 1)) + val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Vector(3, 1)) assert(NDArray.abs(arr).toArray === Array(1f, 2f, 3f)) } test("sign") { - val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Array(3, 1)) + val arr = NDArray.array(Array(-1f, -2f, 3f), shape = Vector(3, 1)) assert(NDArray.sign(arr).toArray === Array(-1f, -1f, 1f)) } test("round") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) assert(NDArray.round(arr).toArray === Array(2f, 2f, 4f)) } test("ceil") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) assert(NDArray.ceil(arr).toArray === Array(2f, 3f, 4f)) } test("floor") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) assert(NDArray.floor(arr).toArray === Array(1f, 2f, 3f)) } test("square") { - val arr = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1)) + val arr = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) assert(NDArray.square(arr).toArray === Array(1f, 4f, 9f)) } @@ -226,30 +226,30 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { } test("max") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) assert(NDArray.max(arr).toScalar === 3.7f +- 1e-3f) } test("min") { - val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Array(3, 1)) + val arr = NDArray.array(Array(1.5f, 2.1f, 3.7f), shape = Vector(3, 1)) assert(NDArray.min(arr).toScalar === 1.5f +- 1e-3f) } test("sum") { - val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Array(2, 2)) + val arr = 
NDArray.array(Array(1f, 2f, 3f, 4f), shape = Vector(2, 2)) assert(NDArray.sum(arr).toScalar === 10f +- 1e-3f) } test("argmaxChannel") { - val arr = NDArray.array(Array(1f, 2f, 4f, 3f), shape = Array(2, 2)) + val arr = NDArray.array(Array(1f, 2f, 4f, 3f), shape = Vector(2, 2)) val argmax = NDArray.argmaxChannel(arr) assert(argmax.shape === Array(2)) assert(argmax.toArray === Array(1f, 0f)) } test("concatenate") { - val arr1 = NDArray.array(Array(1f, 2f, 4f, 3f, 3f, 3f), shape = Array(2, 3)) - val arr2 = NDArray.array(Array(8f, 7f, 6f), shape = Array(1, 3)) + val arr1 = NDArray.array(Array(1f, 2f, 4f, 3f, 3f, 3f), shape = Vector(2, 3)) + val arr2 = NDArray.array(Array(8f, 7f, 6f), shape = Vector(1, 3)) val arr = NDArray.concatenate(arr1, arr2) assert(arr.shape === Array(3, 3)) assert(arr.toArray === Array(1f, 2f, 4f, 3f, 3f, 3f, 8f, 7f, 6f)) @@ -259,7 +259,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { val filename = s"${System.getProperty("java.io.tmpdir")}/ndarray-${sequence.getAndIncrement}.bin" try { - val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1)) + val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) NDArray.save(filename, Map("local" -> ndarray)) val (keys, arrays) = NDArray.load(filename) assert(keys.length === 1) @@ -278,7 +278,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { val filename = s"${System.getProperty("java.io.tmpdir")}/ndarray-${sequence.getAndIncrement}.bin" try { - val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Array(3, 1)) + val ndarray = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) NDArray.save(filename, Array(ndarray)) val (keys, arrays) = NDArray.load(filename) assert(keys.length === 0) diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala new file mode 100644 index 000000000000..8a4c1a2171c7 --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala @@ -0,0 +1,36 @@ +package ml.dmlc.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +class RandomSuite extends FunSuite with BeforeAndAfterAll { + test("uniform on cpu") { + Context.withScope(Context.cpu()) { + val (a, b) = (-10, 10) + val shape = Vector(100, 100) + Random.seed(128) + val un1 = Random.uniform(a, b, shape) + Random.seed(128) + val un2 = Random.uniform(a, b, shape) + assert(un1.toArray === un2.toArray) + assert(Math.abs(un1.toArray.sum / un1.size - (a + b) / 2f) < 0.1) + } + } + + test("normal on cpu") { + val (mu, sigma) = (10f, 2f) + val shape = Vector(100, 100) + Random.seed(128) + val ret1 = Random.normal(mu, sigma, shape) + Random.seed(128) + val ret2 = Random.normal(mu, sigma, shape) + assert(ret1.toArray === ret2.toArray) + + val array = ret1.toArray + val mean = array.sum / ret1.size + val devs = array.map(score => (score - mean) * (score - mean)) + val stddev = Math.sqrt(devs.sum / ret1.size) + + assert(Math.abs(mean - mu) < 0.1) + assert(Math.abs(stddev - sigma) < 0.1) + } +} diff --git a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc index 311ec87ab265..df5907b85566 100644 --- a/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/ml_dmlc_mxnet_native_c_api.cc @@ -1142,3 +1142,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_mxnet_LibInfo_mxExecutorBindX setLongField(env, jexecOut, (long) out); return ret; } + +JNIEXPORT jint JNICALL 
Java_ml_dmlc_mxnet_LibInfo_mxRandomSeed + (JNIEnv *env, jobject obj, jint seed) { + return MXRandomSeed(seed); +} From 0d683bb6aa8b351e37a3830e7b67a7f1808c00a5 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 6 Feb 2016 02:28:52 +0800 Subject: [PATCH 2/4] add NDArray equals, add unit tests for ElementWiseSum and Concat --- scala-package/core/pom.xml | 4 + .../main/scala/ml/dmlc/mxnet/Context.scala | 2 +- .../main/scala/ml/dmlc/mxnet/NDArray.scala | 19 +++- .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 19 ++++ .../test/scala/ml/dmlc/mxnet/CheckUtils.scala | 9 ++ .../scala/ml/dmlc/mxnet/ExecutorSuite.scala | 7 +- .../scala/ml/dmlc/mxnet/NDArraySuite.scala | 27 +++++- .../scala/ml/dmlc/mxnet/OperatorSuite.scala | 96 +++++++++++++++++++ .../scala/ml/dmlc/mxnet/RandomSuite.scala | 4 +- scala-package/pom.xml | 6 ++ 10 files changed, 178 insertions(+), 15 deletions(-) create mode 100644 scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala create mode 100644 scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index e65f9e761598..c0b2bfa696ea 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -85,5 +85,9 @@ org.scalatest scalatest_${scala.binary.version} + + org.scalacheck + scalacheck_${scala.binary.version} + diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala index 3f9b205fa0f7..08e390bf4f2c 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Context.scala @@ -5,7 +5,7 @@ object Context { val devstr2type = Map("cpu" -> 1, "gpu" -> 2, "cpu_pinned" -> 3) private var _defaultCtx = new Context("cpu", 0) - def defaultCtx = _defaultCtx + def defaultCtx: Context = _defaultCtx def cpu(deviceId: Int = 0): Context = { new Context("cpu", deviceId) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 33843979e0a6..41ecdfe2d34b 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -467,7 +467,7 @@ object NDArray { } /** - * Join a sequence of arrays at the first axis + * Join a sequence of arrays at axis-0 * TODO: shall we make it native? * @param arrays */ @@ -601,8 +601,14 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = slice(range._1, range._2) } - def slice(start: Int): NDArray = { - slice(start, shape(0)) + /** + * Return a sliced NDArray at the ith position of axis0 + * NDArray only support continuous slicing on axis 0 + * @param i + * @return a sliced NDArray that shares memory with current one. + */ + def slice(i: Int): NDArray = { + slice(i, i + 1) } /** @@ -811,6 +817,13 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = // Get size of current NDArray. 
def size: Int = shape.product + + override def equals(o: Any): Boolean = o match { + case that: NDArray => { + that.shape == this.shape && that.toArray.sameElements(this.toArray) + } + case _ => false + } } // scalastyle:on finalize diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index 50223765c7bf..c0dfe1c09e2d 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -555,6 +555,10 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { bind(ctx, args, argsGrad, "write", Nil, null) } + def bind(ctx: Context, args: Seq[NDArray], argsGrad: Map[String, NDArray]): Executor = { + bind(ctx, args, argsGrad, "write", Nil, null) + } + def bind(ctx: Context, args: Seq[NDArray]): Executor = { val symbolArguments = listArguments() bindHelper(ctx, symbolArguments, args, null, @@ -732,6 +736,21 @@ object Symbol { createNoCheck("Cast") } + def ElementWiseSum(name: String, inputs: Symbol *): Symbol = { + create("ElementWiseSum", inputs.toArray, Map("name" -> name), null) + } + + def ElementWiseSum(inputs: Seq[Symbol], name: String): Symbol = { + create("ElementWiseSum", inputs.toArray, Map("name" -> name), null) + } + + def Concat(inputs: Seq[Symbol], + paramKwargs: Map[String, Any], + attr: Map[String, String] = null): Symbol = { + create("Concat", inputs.toArray, + paramKwargs.map { case (k, v) => (k, v.toString) }, attr) + } + /** * Create a symbol that groups symbols together. * @param symbols List of symbols to be grouped. diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala new file mode 100644 index 000000000000..d4e796678b52 --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala @@ -0,0 +1,9 @@ +package ml.dmlc.mxnet + +object CheckUtils { + def reldiff(a: NDArray, b: NDArray): Float = { + val diff = NDArray.sum(NDArray.abs(a - b)).toScalar + val norm = NDArray.sum(NDArray.abs(a)).toScalar + diff / norm + } +} diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala index 7fbdb253d778..4da8e6728e2f 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/ExecutorSuite.scala @@ -1,6 +1,7 @@ package ml.dmlc.mxnet import org.scalatest.{BeforeAndAfterAll, FunSuite} +import ml.dmlc.mxnet.CheckUtils._ class ExecutorSuite extends FunSuite with BeforeAndAfterAll { test("bind") { @@ -39,10 +40,4 @@ class ExecutorSuite extends FunSuite with BeforeAndAfterAll { assert(reldiff(lhsGrad, lhsGrad2) < 1e-6) assert(reldiff(rhsGrad, rhsGrad2) < 1e-6) } - - private def reldiff(a: NDArray, b: NDArray): Float = { - val diff = NDArray.sum(NDArray.abs(a - b)).toScalar - val norm = NDArray.sum(NDArray.abs(a)).toScalar - diff / norm - } } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala index b743e2d616a3..4b9c26b36669 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala @@ -4,10 +4,9 @@ import java.io.File import java.util.concurrent.atomic.AtomicInteger import ml.dmlc.mxnet.NDArrayConversions._ -import org.scalatest.{BeforeAndAfterAll, FunSuite} -import 
org.scalactic.Tolerance._ +import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite} -class NDArraySuite extends FunSuite with BeforeAndAfterAll { +class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { private val sequence: AtomicInteger = new AtomicInteger(0) test("to java array") { @@ -298,4 +297,26 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll { assert(ctx.deviceType === "cpu") assert(ctx.deviceId === 0) } + + test("equals") { + val ndarray1 = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) + val ndarray2 = NDArray.array(Array(1f, 2f, 3f), shape = Vector(3, 1)) + val ndarray3 = NDArray.array(Array(1f, 2f, 3f), shape = Vector(1, 3)) + val ndarray4 = NDArray.array(Array(3f, 2f, 3f), shape = Vector(3, 1)) + ndarray1 shouldEqual ndarray2 + ndarray1 shouldNot equal(ndarray3) + ndarray1 shouldNot equal(ndarray4) + } + + test("slice") { + val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Vector(3, 2)) + + val arr1 = arr.slice(1) + assert(arr1.shape === Vector(1, 2)) + assert(arr1.toArray === Array(3f, 4f)) + + val arr2 = arr.slice(1, 3) + assert(arr2.shape === Vector(2, 2)) + assert(arr2.toArray === Array(3f, 4f, 5f, 6f)) + } } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala new file mode 100644 index 000000000000..a8fecee3acb9 --- /dev/null +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala @@ -0,0 +1,96 @@ +package ml.dmlc.mxnet + +import ml.dmlc.mxnet.Base.Shape +import ml.dmlc.mxnet.CheckUtils._ +import org.scalatest.prop.GeneratorDrivenPropertyChecks +import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite} +import org.scalacheck.Gen +import scala.collection.mutable + +class OperatorSuite extends FunSuite with BeforeAndAfterAll + with Matchers with GeneratorDrivenPropertyChecks { + private def checkElementwiseSumWithShape(shape: Shape, n: Int) = { + // forward + val inputs = (0 until n).map(i => Symbol.Variable(s"arg $i")) + val out = Symbol.ElementWiseSum("esum", inputs: _*) + val arr = (0 until n).map(_ => Random.uniform(-10, 10, shape)) + val arrGrad = (0 until n).map(_ => NDArray.empty(shape)) + val exec = out.bind(Context.cpu(), args = arr, argsGrad = arrGrad) + exec.forward() + val forwardOutput = exec.outputs(0) + val forwardOutputExpected = arr.reduce(_ + _) + assert(reldiff(forwardOutput, forwardOutputExpected) < 1e-6) + + // backward + val outGrad = Random.uniform(-10, 10, shape) + exec.backward(outGrad) + arrGrad.foreach(grad => assert(grad === outGrad)) + } + + test("elementwise sum") { + checkElementwiseSumWithShape(Vector(5, 5, 3), 4) + forAll (Gen.choose(1, 4), Gen.choose(1, 8)) { (dim, n) => + forAll (Gen.listOfN(dim, Gen.choose(1, Math.pow(1000, 1.0 / dim).toInt))) { shape => + checkElementwiseSumWithShape(shape.toVector, n) + } + } + } + + // TODO: checkSliceChannel + + private def checkConcatWithShape(shapes: Seq[Shape], dimension: Int, skipSecond: Boolean) = { + // if skipSecond is true, second argument will not have gradient. 
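+    // ("arg1" is then left out of the gradient map dictGrad assembled below)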
+ // it is to test #1130 + // forward + val targetDim = shapes.map(_(dimension)).sum + + val inputs = (0 until shapes.size).map(i => Symbol.Variable(s"arg$i")) + val out = Symbol.Concat(inputs, Map("name" -> "conc", "dim" -> dimension)) + val arr = shapes.map { shape => + val nd = NDArray.empty(shape) + nd.set(shape(dimension)) + } + val arrNp = arr.map(_.copy()) + val arrGrad = shapes.map(NDArray.empty(_)) + val argNames = out.listArguments() + val dictGrad = + (argNames zip arrGrad).filter { case (name, d) => + !skipSecond || name != "arg1" + }.toMap + + val args = out.listArguments() + val (argShapes, outShapes, auxShapes) = out.inferShape(args.zip(shapes).toMap) + val outGrad = NDArray.empty(outShapes(0)) + val exec1 = out.bind(Context.cpu(), arr, dictGrad) + exec1.forward() + val out1 = exec1.outputs(0) + // FIXME: only support concatenate at axis0 + val ret = NDArray.concatenate(arr) + assert(out1 === ret) + + // backward + out1.copyTo(outGrad) + outGrad += 1 + exec1.backward(outGrad) + argNames.zipWithIndex.foreach { case (name, i) => + if (!skipSecond || name != "arg1") { + val grad = dictGrad(name) + val npGrad = arrNp(i) + assert(grad === npGrad + 1) + } + } + } + + test("concat") { + val merge = Array(2, 3, 4, 5, 6) + forAll (Gen.choose(2, 5)) { dim => + val shapes = mutable.ArrayBuffer.empty[Vector[Int]] + for (i <- 0 until dim) { + shapes += Vector(merge(i), 2) + } + // TODO: check dimension > 0 + checkConcatWithShape(shapes, 0, skipSecond = true) + checkConcatWithShape(shapes, 0, skipSecond = false) + } + } +} diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala index 8a4c1a2171c7..b1913b5e10d7 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/RandomSuite.scala @@ -11,7 +11,7 @@ class RandomSuite extends FunSuite with BeforeAndAfterAll { val un1 = Random.uniform(a, b, shape) Random.seed(128) val un2 = Random.uniform(a, b, shape) - assert(un1.toArray === un2.toArray) + assert(un1 === un2) assert(Math.abs(un1.toArray.sum / un1.size - (a + b) / 2f) < 0.1) } } @@ -23,7 +23,7 @@ class RandomSuite extends FunSuite with BeforeAndAfterAll { val ret1 = Random.normal(mu, sigma, shape) Random.seed(128) val ret2 = Random.normal(mu, sigma, shape) - assert(ret1.toArray === ret2.toArray) + assert(ret1 === ret2) val array = ret1.toArray val mean = array.sum / ret1.size diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 3e2a9a2494be..28b7b44a9ed9 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -233,6 +233,12 @@ 2.2.4 test + + org.scalacheck + scalacheck_${scala.binary.version} + 1.11.3 + test + From a0559bc8c6d0c246d68a2f7618532385f8b3aa14 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 6 Feb 2016 02:36:56 +0800 Subject: [PATCH 3/4] Implement NDArray hashCode --- .../core/src/main/scala/ml/dmlc/mxnet/NDArray.scala | 5 +++++ .../core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 41ecdfe2d34b..8ad6d32248d3 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -824,6 +824,11 @@ class NDArray(private[mxnet] val handle: NDArrayHandle, val writable: Boolean = } case _ => false } + + override def 
hashCode: Int = { + // TODO: naive implementation + shape.hashCode + toArray.hashCode + } } // scalastyle:on finalize diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala index 4b9c26b36669..eacc8d6b3206 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/NDArraySuite.scala @@ -309,7 +309,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("slice") { - val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Vector(3, 2)) + val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Vector(3, 2)) val arr1 = arr.slice(1) assert(arr1.shape === Vector(1, 2)) From 9cf857bb8d6fb886596a3cd2d08e0de52a030987 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 9 Feb 2016 01:05:25 -0800 Subject: [PATCH 4/4] add Symbol::SwapAxis, tested --- .../src/main/scala/ml/dmlc/mxnet/Model.scala | 1 + .../src/main/scala/ml/dmlc/mxnet/Symbol.scala | 33 ++++++-- .../test/scala/ml/dmlc/mxnet/CheckUtils.scala | 7 ++ .../scala/ml/dmlc/mxnet/OperatorSuite.scala | 78 +++++++++++++++++++ 4 files changed, 114 insertions(+), 5 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala index 1d4642f3cbf7..e99dc30c12f8 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Model.scala @@ -426,6 +426,7 @@ class FeedForward(val symbol: Symbol, val ctx: Array[Context] = Array(Context.cp } batch = data.next() } + // TODO: we can use Symbol.concat to do the same thing. Can it be more efficient? outputs.map(NDArray.concatenate(_)) } diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index c0dfe1c09e2d..43c382823955 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -674,6 +674,7 @@ class Symbol(private[mxnet] val handle: SymbolHandle) { } object Symbol { + private type SymbolCreateFunc = Map[String, Any] => Symbol private val logger = LoggerFactory.getLogger(classOf[Symbol]) private val functions: Map[String, SymbolFunction] = initSymbolModule() private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) @@ -692,23 +693,23 @@ object Symbol { sym } - def FullyConnected: Map[String, Any] => Symbol = { + def FullyConnected: SymbolCreateFunc = { FullyConnected(null) } - def FullyConnected(attr: Map[String, String]): Map[String, Any] => Symbol = { + def FullyConnected(attr: Map[String, String]): SymbolCreateFunc = { createNoCheck("FullyConnected", attr) } - def Activation: Map[String, Any] => Symbol = { + def Activation: SymbolCreateFunc = { Activation(null) } - def Activation(attr: Map[String, String]): Map[String, Any] => Symbol = { + def Activation(attr: Map[String, String]): SymbolCreateFunc = { createNoCheck("Activation", attr) } - def Convolution(attr: Map[String, String]): Map[String, Any] => Symbol = { + def Convolution(attr: Map[String, String]): SymbolCreateFunc = { createNoCheck("Convolution", attr) } @@ -751,6 +752,28 @@ object Symbol { paramKwargs.map { case (k, v) => (k, v.toString) }, attr) } + // Use Logistic regression for final output, this is used on final output of a net. 
+ // Logistic regression is suitable for binary classification or probability prediction tasks. + def LogisticRegressionOutput(inputs: Seq[Symbol], attr: Map[String, String] = null): Symbol = { + create("LogisticRegressionOutput", inputs.toArray, null, attr) + } + + // Use linear regression for final output, this is used on final output of a net. + def LinearRegressionOutput(inputs: Seq[Symbol], attr: Map[String, String] = null): Symbol = { + create("LinearRegressionOutput", inputs.toArray, null, attr) + } + + /** + * Apply swapaxis to input. + * @param data Input data to the SwapAxisOp. + * @param dim1 (non-negative), default=0, the first axis to be swapped. + * @param dim2 (non-negative), default=0, the second axis to be swapped. + */ + def SwapAxis(data: Symbol, dim1: Int = 0, dim2: Int = 0, + attr: Map[String, String] = null): Symbol = { + createNoCheck("SwapAxis")(Map("data" -> data, "dim1" -> dim1, "dim2" -> dim2)) + } + /** * Create a symbol that groups symbols together. * @param symbols List of symbols to be grouped. diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala index d4e796678b52..2fc2f02d09a2 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/CheckUtils.scala @@ -6,4 +6,11 @@ object CheckUtils { val norm = NDArray.sum(NDArray.abs(a)).toScalar diff / norm } + + def reldiff(a: Array[Float], b: Array[Float]): Float = { + val diff = + (a zip b).map { case (aElem, bElem) => Math.abs(aElem - bElem) }.sum + val norm: Float = a.reduce(Math.abs(_) + Math.abs(_)) + diff / norm + } } diff --git a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala index a8fecee3acb9..8e99553af3b8 100644 --- a/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala +++ b/scala-package/core/src/test/scala/ml/dmlc/mxnet/OperatorSuite.scala @@ -93,4 +93,82 @@ class OperatorSuite extends FunSuite with BeforeAndAfterAll checkConcatWithShape(shapes, 0, skipSecond = false) } } + + private def checkRegression(model: Symbol, + forward: Float => Float, + backward: (Float, Float) => Float) = { + val shape = Vector(3, 1) + val arrData = Random.uniform(-1, 1, shape) + val arrLabel = Random.uniform(0, 1, Vector(shape.head)) + val arrGrad = NDArray.empty(shape) + val exec1 = model.bind(Context.cpu(), + args = Array(arrData, arrLabel), argsGrad = Map("data" -> arrGrad)) + exec1.forward() + assert(exec1.outputs(0).shape === shape) + val out1 = exec1.outputs(0).toArray + val npout = arrData.toArray.map(forward(_)) + assert(CheckUtils.reldiff(npout, out1) < 1e-6f) + + exec1.backward() + // arrData shape: Vector(3, 1) + // arrLabel shape: Vector(3) + val npoutBack = (npout zip arrLabel.toArray).map { case (data, label) => + backward(data, label) + } + assert(CheckUtils.reldiff(npoutBack, arrGrad.toArray) < 1e-6f) + } + + test("regression") { + checkRegression(Symbol.LogisticRegressionOutput( + Array(Symbol.Variable("data"), Symbol.Variable("label"))), + (x: Float) => 1.0f / (1.0f + Math.exp(-x).toFloat), + (x: Float, y: Float) => x - y) + checkRegression(Symbol.LinearRegressionOutput( + Array(Symbol.Variable("data"), Symbol.Variable("label"))), + (x: Float) => x, + (x: Float, y: Float) => x - y) + } + + // TODO: test softmax + + test("swap axes") { + val data = Symbol.Variable("data") + val shape = Vector(2, 3, 4) + val arrData = NDArray.ones(shape) + 
arrData.slice(0).set(1f) + arrData.slice(1).set(2f) + // arrData = + // + // [[[ 1., 1., 1., 1.], + // [ 1., 1., 1., 1.], + // [ 1., 1., 1., 1.]], + // + // [[ 2., 2., 2., 2.], + // [ 2., 2., 2., 2.], + // [ 2., 2., 2., 2.]]] + val swap0 = Symbol.SwapAxis(data = data, dim1 = 0, dim2 = 2) + val swap = Symbol.SwapAxis(data = swap0, dim1 = 1, dim2 = 2) + val exec = swap.bind(Context.cpu(), args = Array(arrData)) + exec.forward() + val out = exec.outputs(0) + + // After swapaxes(swapaxes(arrData, 0, 2), 1, 2) + // out should be + // [[[ 1., 1., 1.], + // [ 2., 2., 2.]], + // + // [[ 1., 1., 1.], + // [ 2., 2., 2.]], + // + // [[ 1., 1., 1.], + // [ 2., 2., 2.]], + // + // [[ 1., 1., 1.], + // [ 2., 2., 2.]]] + assert(out.shape === Vector(4, 2, 3)) + for (i <- 0 until 4) { + val axis0 = out.slice(i) + assert(CheckUtils.reldiff(axis0.toArray, Array(1f, 1f, 1f, 2f, 2f, 2f)) < 1e-6f) + } + } }
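
For readers trying the series end to end, here is a minimal usage sketch of what the four patches add together. The SeriesDemo object and its literal values are illustrative only (not part of the patches), and it assumes the scala-package and the native library above are built and on the path:

    import ml.dmlc.mxnet._

    object SeriesDemo {
      def main(args: Array[String]): Unit = {
        // Context.withScope (patch 1) installs a default context for the
        // duration of the body and restores the previous one afterwards,
        // even if the body throws.
        Context.withScope(Context.cpu()) {
          // Base.Shape is now an immutable Vector[Int] instead of Array[Int]
          val shape = Vector(2, 3)

          // mxRandomSeed (patch 1) makes the native generators reproducible
          Random.seed(42)
          val a = Random.uniform(-1, 1, shape)
          Random.seed(42)
          val b = Random.uniform(-1, 1, shape)

          // NDArray.equals (patch 2) compares shape and element values, so
          // reseeding with the same seed yields arrays that compare equal
          require(a == b)

          // slice(i) (patch 2) is a one-row view sharing the underlying memory
          require(a.slice(1).shape == Vector(1, 3))
        }
      }
    }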
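Patch 4's SwapAxis can be exercised the same way outside the test suite; under the same assumptions as the sketch above, a single swap of axes 0 and 2 turns a (2, 3, 4) array into a (4, 3, 2) one:

    val data = Symbol.Variable("data")
    val swapped = Symbol.SwapAxis(data = data, dim1 = 0, dim2 = 2)
    // bind with one input and no gradient buffers, then run forward
    val exec = swapped.bind(Context.cpu(), args = Array(NDArray.ones(Vector(2, 3, 4))))
    exec.forward()
    require(exec.outputs(0).shape == Vector(4, 3, 2))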
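One caveat on patch 3: toArray returns an Array[Float], and JVM arrays hash by reference, so two NDArrays that compare equal under the new equals can still produce different hash codes. The TODO already marks the implementation as naive; a contract-consistent variant, sketched here with the same O(size) copy-back cost that equals pays, would hash the element values instead:

    override def hashCode: Int = {
      // toSeq wraps the array in a WrappedArray, whose hashCode is value-based
      41 * shape.hashCode + toArray.toSeq.hashCode
    }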