From dca971eb943bf55264faf96315426c5655d7372f Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 2 Oct 2018 11:26:26 -0700 Subject: [PATCH] fix bug and add tests --- .../main/scala/org/apache/mxnet/NDArray.scala | 2 +- .../scala/org/apache/mxnet/NDArraySuite.scala | 29 ++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 3ffca90b8166..61d653dd405d 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -449,7 +449,7 @@ object NDArray extends NDArrayBase { case arrAny : Array[Any] => { val fragment = (end - start + 1) / arrAny.length for (i <- arrAny.indices) - arrayCombiner(arrAny(i), arr, start + i * fragment, end + (i + 1) * fragment) + arrayCombiner(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment) } case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 5d88bb39e502..15d19401582d 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -21,7 +21,9 @@ import java.io.File import java.util.concurrent.atomic.AtomicInteger import org.apache.mxnet.NDArrayConversions._ -import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} + +import scala.collection.mutable.ArrayBuffer class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { private val sequence: AtomicInteger = new AtomicInteger(0) @@ -65,6 +67,31 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(ndarray.toArray === Array(1f, 2f, 3f, 4f)) } + test("create NDArray based on Java Matrix") { + val arrBuf = ArrayBuffer[Array[Float]]() + for (i <- 0 until 100) arrBuf += Array(1.0f, 1.0f, 1.0f, 1.0f) + val arr = Array( + Array( + arrBuf.toArray + ), + Array( + arrBuf.toArray + ) + ) + var nd = NDArray.toNDArray(arr) + require(nd.shape == Shape(2, 1, 100, 4)) + val arr2 = Array(1.0f, 1.0f, 1.0f, 1.0f) + nd = NDArray.toNDArray(arr2) + require(nd.shape == Shape(4)) + } + + test("test Visualize") { + var nd = NDArray.ones(Shape(1, 2, 100, 1)) + nd.visualize + nd = NDArray.ones(Shape(1, 4)) + nd.visualize + } + test("plus") { val ndzeros = NDArray.zeros(2, 1) val ndones = ndzeros + 1f