Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix bug and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Oct 2, 2018
1 parent a9b5855 commit dca971e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ object NDArray extends NDArrayBase {
case arrAny : Array[Any] => {
val fragment = (end - start + 1) / arrAny.length
for (i <- arrAny.indices)
arrayCombiner(arrAny(i), arr, start + i * fragment, end + (i + 1) * fragment)
arrayCombiner(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment)
}
case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import java.io.File
import java.util.concurrent.atomic.AtomicInteger

import org.apache.mxnet.NDArrayConversions._
import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

import scala.collection.mutable.ArrayBuffer

class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
private val sequence: AtomicInteger = new AtomicInteger(0)
Expand Down Expand Up @@ -65,6 +67,31 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(ndarray.toArray === Array(1f, 2f, 3f, 4f))
}

test("create NDArray based on Java Matrix") {
val arrBuf = ArrayBuffer[Array[Float]]()
for (i <- 0 until 100) arrBuf += Array(1.0f, 1.0f, 1.0f, 1.0f)
val arr = Array(
Array(
arrBuf.toArray
),
Array(
arrBuf.toArray
)
)
var nd = NDArray.toNDArray(arr)
require(nd.shape == Shape(2, 1, 100, 4))
val arr2 = Array(1.0f, 1.0f, 1.0f, 1.0f)
nd = NDArray.toNDArray(arr2)
require(nd.shape == Shape(4))
}

test("test Visualize") {
var nd = NDArray.ones(Shape(1, 2, 100, 1))
nd.visualize
nd = NDArray.ones(Shape(1, 4))
nd.visualize
}

test("plus") {
val ndzeros = NDArray.zeros(2, 1)
val ndones = ndzeros + 1f
Expand Down

0 comments on commit dca971e

Please sign in to comment.