Skip to content
This repository has been archived by the owner on Feb 9, 2021. It is now read-only.

Commit

Permalink
[MXNET-1000] get Ndarray real value and form it from a NDArray (apach…
Browse files Browse the repository at this point in the history
…e#12690)

* add visualize

* adding Any type input to form NDArray

* fix bug and add tests

* add a toString method

* add Visualize Util and migrate visualize structure to there

* update with tests

* refactor code

* fix the minor issue

* add multiple types support

* add changes on names and tests

* make code elegant and improve readability
  • Loading branch information
lanking520 authored and Gordon Reid committed Jan 27, 2019
1 parent 895037f commit 21bffb0
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,10 @@ object MX_PRIMITIVES {

implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data

def isValidMxPrimitiveType(num : Any) : Boolean = {
num match {
case valid @ (_: Float | _: Double) => true
case _ => false
}
}
}
112 changes: 111 additions & 1 deletion scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.language.implicitConversions
import scala.ref.WeakReference
import scala.util.Try

/**
* NDArray Object extends from NDArrayBase for abstract function signatures
Expand Down Expand Up @@ -509,6 +510,61 @@ object NDArray extends NDArrayBase {
array(sourceArr, shape, null)
}

/**
* Create a new NDArray based on the structure of source Array
* @param sourceArr Array[Array...Array[MX_PRIMITIVE_TYPE]...]
* @param ctx context like to pass in
* @return an NDArray with the same shape of the input
* @throws IllegalArgumentException if the data type is not valid
*/
def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = {
val shape = shapeGetter(sourceArr)
val container = new Array[Any](shape.product)
flattenArray(sourceArr, container, 0, container.length - 1)
val finalArr = container(0) match {
case f: Float => array(container.map(_.asInstanceOf[Float]), Shape(shape), ctx)
case d: Double => array(container.map(_.asInstanceOf[Double]), Shape(shape), ctx)
case _ => throw new IllegalArgumentException(
s"Unsupported type ${container(0).getClass}, please check MX_PRIMITIVES for valid types")
}
finalArr
}

private def shapeGetter(sourceArr : Any) : ArrayBuffer[Int] = {
sourceArr match {
// e.g : Array[Double] the inner layer
case arr: Array[_] if MX_PRIMITIVES.isValidMxPrimitiveType(arr(0)) => {
ArrayBuffer[Int](arr.length)
}
// e.g : Array[Array...[]]
case arr: Array[_] => {
var arrBuffer = new ArrayBuffer[Int]()
if (!arr.isEmpty) arrBuffer = shapeGetter(arr(0))
for (idx <- arr.indices) {
require(arrBuffer == shapeGetter(arr(idx)))
}
arrBuffer.insert(0, arr.length)
arrBuffer
}
case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}")
}
}

private def flattenArray(sourceArr : Any, arr : Array[Any],
start : Int, end : Int) : Unit = {
sourceArr match {
case arrValid: Array[_] if MX_PRIMITIVES.isValidMxPrimitiveType(arrValid(0)) => {
for (i <- arrValid.indices) arr(start + i) = arrValid(i)
}
case arrAny: Array[_] => {
val fragment = (end - start + 1) / arrAny.length
for (i <- arrAny.indices)
flattenArray(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment)
}
case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}")
}
}

/**
* Returns evenly spaced values within a given interval.
* Values are generated within the half-open interval [`start`, `stop`). In other
Expand Down Expand Up @@ -667,7 +723,6 @@ object NDArray extends NDArrayBase {
genericNDArrayFunctionInvoke("_crop_assign", args, kwargs)
}

// TODO: imdecode
}

/**
Expand All @@ -694,6 +749,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]]

private val lengthProperty = "mxnet.setNDArrayPrintLength"
private val layerProperty = "mxnet.setNDArrayPrintLayerLength"
private lazy val printLength = Try(System.getProperty(lengthProperty).toInt).getOrElse(1000)
private lazy val layerLength = Try(System.getProperty(layerProperty).toInt).getOrElse(10)

def serialize(): Array[Byte] = {
val buf = ArrayBuffer.empty[Byte]
checkCall(_LIB.mxNDArraySaveRawBytes(handle, buf))
Expand Down Expand Up @@ -763,6 +823,56 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length))
}

/**
* Visualize the internal structure of NDArray
* @return String that show the structure
*/
override def toString: String = {
val abstractND = buildStringHelper(this, this.shape.length)
val otherInfo = s"<NDArray ${this.shape} ${this.context} ${this.dtype}>"
s"$abstractND\n$otherInfo"
}

/**
* Helper function to create formatted NDArray output
* The NDArray will be represented in a reduced version if too large
* @param nd NDArray as the input
* @param totalSpace totalSpace of the lowest dimension
* @return String format of NDArray
*/
private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = {
var result = ""
val THRESHOLD = layerLength // longest NDArray[NDArray[...]] to show in full
val ARRAYTHRESHOLD = printLength // longest array to show in full
val shape = nd.shape
val space = totalSpace - shape.length
if (shape.length != 1) {
val (length, postfix) =
if (shape(0) > THRESHOLD) {
// reduced NDArray
(10, s"\n${" " * (space + 1)}... with length ${shape(0)}\n")
} else {
(shape(0), "")
}
for (num <- 0 until length) {
val output = buildStringHelper(nd.at(num), totalSpace)
result += s"$output\n"
}
result = s"${" " * space}[\n$result${" " * space}$postfix${" " * space}]"
} else {
if (shape(0) > ARRAYTHRESHOLD) {
// reduced Array
val front = nd.slice(0, 10)
val back = nd.slice(shape(0) - 10, shape(0) - 1)
result = s"""${" " * space}[${front.toArray.mkString(",")}
| ... ${back.toArray.mkString(",")}]""".stripMargin
} else {
result = s"${" " * space}[${nd.toArray.mkString(",")}]"
}
}
result
}

/**
* Return a sliced NDArray that shares memory with current one.
* NDArray only support continuous slicing on axis 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ import java.util.concurrent.atomic.AtomicInteger

import org.apache.mxnet.NDArrayConversions._
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
import org.slf4j.LoggerFactory
import scala.collection.mutable.ArrayBuffer

class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
private val sequence: AtomicInteger = new AtomicInteger(0)

private val logger = LoggerFactory.getLogger(classOf[NDArraySuite])

test("to java array") {
val ndarray = NDArray.zeros(2, 2)
assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
Expand Down Expand Up @@ -85,6 +89,84 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(ndarray.toArray === Array(1f, 2f, 3f, 4f))
}

test("create NDArray based on Java Matrix") {
def arrayGen(num : Any) : Array[Any] = {
val array = num match {
case f: Float =>
(for (_ <- 0 until 100) yield Array(1.0f, 1.0f, 1.0f, 1.0f)).toArray
case d: Double =>
(for (_ <- 0 until 100) yield Array(1.0d, 1.0d, 1.0d, 1.0d)).toArray
case _ => throw new IllegalArgumentException(s"Unsupported Type ${num.getClass}")
}
Array(
Array(
array
),
Array(
array
)
)
}
val floatData = 1.0f
var nd = NDArray.toNDArray(arrayGen(floatData))
require(nd.shape == Shape(2, 1, 100, 4))
val arr2 = Array(1.0f, 1.0f, 1.0f, 1.0f)
nd = NDArray.toNDArray(arr2)
require(nd.shape == Shape(4))
val doubleData = 1.0d
nd = NDArray.toNDArray(arrayGen(doubleData))
require(nd.shape == Shape(2, 1, 100, 4))
require(nd.dtype == DType.Float64)
}

test("test Visualize") {
var nd = NDArray.ones(Shape(1, 2, 1000, 1))
var data : String =
"""
|[
| [
| [
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
|
| ... with length 1000
| ]
| [
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
| [1.0]
|
| ... with length 1000
| ]
| ]
|]
|<NDArray (1,2,1000,1) cpu(0) float32>""".stripMargin
require(nd.toString.split("\\s+").mkString == data.split("\\s+").mkString)
nd = NDArray.ones(Shape(1, 4))
data =
"""
|[
| [1.0,1.0,1.0,1.0]
|]
|<NDArray (1,4) cpu(0) float32>""".stripMargin
require(nd.toString.split("\\s+").mkString == data.split("\\s+").mkString)
}

test("plus") {
var ndzeros = NDArray.zeros(2, 1)
var ndones = ndzeros + 1f
Expand Down

0 comments on commit 21bffb0

Please sign in to comment.