Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add Sparse NDIter test
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Jun 30, 2019
1 parent 8e34661 commit 8ae0c9f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ class SparseNDArray private[mxnet] (override private[mxnet] val handle: NDArrayH
dense.at(idx)
}

override def slice(start: Int, end: Int): NDArray = {
printf(s"\n\nSlice being called!!\n\nstart:$start end:$end")
NDArray.api.slice(this, Shape(start), Shape(end))
}

/**
* Get the Data portion from the row Sparse NDArray
* @return NDArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.mxnet

import org.apache.mxnet.io.NDArrayIter
import org.scalatest.FunSuite
import org.slf4j.LoggerFactory

Expand Down Expand Up @@ -62,7 +63,6 @@ class SparseNDArraySuite extends FunSuite {
)
val indices = Array(0f, 1f, 3f)
val rspIn = SparseNDArray.rowSparseArray(arr, indices, Shape(4, 2), Context.cpu())
printf(rspIn.toString)
val toRetain = Array(0f, 3f)
val rspOut = SparseNDArray.retain(rspIn, toRetain)
assert(rspOut.getData.toArray sameElements Array(1f, 2f, 5f, 6f))
Expand All @@ -76,5 +76,18 @@ class SparseNDArraySuite extends FunSuite {
assert(nd2.toArray sameElements Array(2f, 4f, 6f))
}

test("Test DataIter") {
val nd = NDArray.array(Array(1f, 2f, 3f), Shape(1, 3)).toSparse(Some(SparseFormat.CSR))
val arr = IndexedSeq(nd, nd, nd, nd)
val iter = new NDArrayIter(arr)
while (iter.hasNext) {
val tempArr = iter.next().data
tempArr.foreach(ele => {
assert(ele.sparseFormat == SparseFormat.CSR)
assert(ele.shape == Shape(1, 3))
})
}
}


}

0 comments on commit 8ae0c9f

Please sign in to comment.