From cff4a5bbb48f0ceec22a7ba56d16eee0954f9161 Mon Sep 17 00:00:00 2001 From: Lanking Date: Fri, 25 Jan 2019 12:01:02 -0800 Subject: [PATCH] [MXNET-1000] get Ndarray real value and form it from a NDArray (#12690) * add visualize * adding Any type input to form NDArray * fix bug and add tests * add a toString method * add Visualize Util and migrate visualize structure to there * update with tests * refactor code * fix the minor issue * add multiple types support * add changes on names and tests * make code elegant and improve readability --- .../org/apache/mxnet/MX_PRIMITIVES.scala | 6 + .../main/scala/org/apache/mxnet/NDArray.scala | 112 +++++++++++++++++- .../scala/org/apache/mxnet/NDArraySuite.scala | 82 +++++++++++++ 3 files changed, 199 insertions(+), 1 deletion(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala index cb978856963c..3a51222cc0b8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala @@ -82,4 +82,10 @@ object MX_PRIMITIVES { implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data + def isValidMxPrimitiveType(num : Any) : Boolean = { + num match { + case valid @ (_: Float | _: Double) => true + case _ => false + } + } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 163ed2682532..5c345f21faf4 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -28,6 +28,7 @@ import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.language.implicitConversions import scala.ref.WeakReference +import scala.util.Try /** * NDArray Object extends from NDArrayBase for abstract function signatures @@ -509,6 +510,61 @@ object NDArray extends NDArrayBase { array(sourceArr, shape, null) } + /** + * Create a new NDArray based on the structure of source Array + * @param sourceArr Array[Array...Array[MX_PRIMITIVE_TYPE]...] + * @param ctx context like to pass in + * @return an NDArray with the same shape of the input + * @throws IllegalArgumentException if the data type is not valid + */ + def toNDArray(sourceArr: Array[_], ctx : Context = null) : NDArray = { + val shape = shapeGetter(sourceArr) + val container = new Array[Any](shape.product) + flattenArray(sourceArr, container, 0, container.length - 1) + val finalArr = container(0) match { + case f: Float => array(container.map(_.asInstanceOf[Float]), Shape(shape), ctx) + case d: Double => array(container.map(_.asInstanceOf[Double]), Shape(shape), ctx) + case _ => throw new IllegalArgumentException( + s"Unsupported type ${container(0).getClass}, please check MX_PRIMITIVES for valid types") + } + finalArr + } + + private def shapeGetter(sourceArr : Any) : ArrayBuffer[Int] = { + sourceArr match { + // e.g : Array[Double] the inner layer + case arr: Array[_] if MX_PRIMITIVES.isValidMxPrimitiveType(arr(0)) => { + ArrayBuffer[Int](arr.length) + } + // e.g : Array[Array...[]] + case arr: Array[_] => { + var arrBuffer = new ArrayBuffer[Int]() + if (!arr.isEmpty) arrBuffer = shapeGetter(arr(0)) + for (idx <- arr.indices) { + require(arrBuffer == shapeGetter(arr(idx))) + } + arrBuffer.insert(0, arr.length) + arrBuffer + } + case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") + } + } + + private def flattenArray(sourceArr : Any, arr : Array[Any], + start : Int, end : Int) : Unit = { + sourceArr match { + case arrValid: Array[_] if MX_PRIMITIVES.isValidMxPrimitiveType(arrValid(0)) => { + for (i <- arrValid.indices) arr(start + i) = arrValid(i) + } + case arrAny: Array[_] => { + val fragment = (end - start + 1) / arrAny.length + for (i <- arrAny.indices) + flattenArray(arrAny(i), arr, start + i * fragment, start + (i + 1) * fragment) + } + case _ => throw new IllegalArgumentException(s"Wrong type passed: ${sourceArr.getClass}") + } + } + /** * Returns evenly spaced values within a given interval. * Values are generated within the half-open interval [`start`, `stop`). In other @@ -667,7 +723,6 @@ object NDArray extends NDArrayBase { genericNDArrayFunctionInvoke("_crop_assign", args, kwargs) } - // TODO: imdecode } /** @@ -694,6 +749,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] + private val lengthProperty = "mxnet.setNDArrayPrintLength" + private val layerProperty = "mxnet.setNDArrayPrintLayerLength" + private lazy val printLength = Try(System.getProperty(lengthProperty).toInt).getOrElse(1000) + private lazy val layerLength = Try(System.getProperty(layerProperty).toInt).getOrElse(10) + def serialize(): Array[Byte] = { val buf = ArrayBuffer.empty[Byte] checkCall(_LIB.mxNDArraySaveRawBytes(handle, buf)) @@ -763,6 +823,56 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length)) } + /** + * Visualize the internal structure of NDArray + * @return String that show the structure + */ + override def toString: String = { + val abstractND = buildStringHelper(this, this.shape.length) + val otherInfo = s"" + s"$abstractND\n$otherInfo" + } + + /** + * Helper function to create formatted NDArray output + * The NDArray will be represented in a reduced version if too large + * @param nd NDArray as the input + * @param totalSpace totalSpace of the lowest dimension + * @return String format of NDArray + */ + private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { + var result = "" + val THRESHOLD = layerLength // longest NDArray[NDArray[...]] to show in full + val ARRAYTHRESHOLD = printLength // longest array to show in full + val shape = nd.shape + val space = totalSpace - shape.length + if (shape.length != 1) { + val (length, postfix) = + if (shape(0) > THRESHOLD) { + // reduced NDArray + (10, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") + } else { + (shape(0), "") + } + for (num <- 0 until length) { + val output = buildStringHelper(nd.at(num), totalSpace) + result += s"$output\n" + } + result = s"${" " * space}[\n$result${" " * space}$postfix${" " * space}]" + } else { + if (shape(0) > ARRAYTHRESHOLD) { + // reduced Array + val front = nd.slice(0, 10) + val back = nd.slice(shape(0) - 10, shape(0) - 1) + result = s"""${" " * space}[${front.toArray.mkString(",")} + | ... ${back.toArray.mkString(",")}]""".stripMargin + } else { + result = s"${" " * space}[${nd.toArray.mkString(",")}]" + } + } + result + } + /** * Return a sliced NDArray that shares memory with current one. * NDArray only support continuous slicing on axis 0 diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index bc7a0a026bc3..054300e952a8 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -22,10 +22,14 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.mxnet.NDArrayConversions._ import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.slf4j.LoggerFactory +import scala.collection.mutable.ArrayBuffer class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { private val sequence: AtomicInteger = new AtomicInteger(0) + private val logger = LoggerFactory.getLogger(classOf[NDArraySuite]) + test("to java array") { val ndarray = NDArray.zeros(2, 2) assert(ndarray.toArray === Array(0f, 0f, 0f, 0f)) @@ -85,6 +89,84 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(ndarray.toArray === Array(1f, 2f, 3f, 4f)) } + test("create NDArray based on Java Matrix") { + def arrayGen(num : Any) : Array[Any] = { + val array = num match { + case f: Float => + (for (_ <- 0 until 100) yield Array(1.0f, 1.0f, 1.0f, 1.0f)).toArray + case d: Double => + (for (_ <- 0 until 100) yield Array(1.0d, 1.0d, 1.0d, 1.0d)).toArray + case _ => throw new IllegalArgumentException(s"Unsupported Type ${num.getClass}") + } + Array( + Array( + array + ), + Array( + array + ) + ) + } + val floatData = 1.0f + var nd = NDArray.toNDArray(arrayGen(floatData)) + require(nd.shape == Shape(2, 1, 100, 4)) + val arr2 = Array(1.0f, 1.0f, 1.0f, 1.0f) + nd = NDArray.toNDArray(arr2) + require(nd.shape == Shape(4)) + val doubleData = 1.0d + nd = NDArray.toNDArray(arrayGen(doubleData)) + require(nd.shape == Shape(2, 1, 100, 4)) + require(nd.dtype == DType.Float64) + } + + test("test Visualize") { + var nd = NDArray.ones(Shape(1, 2, 1000, 1)) + var data : String = + """ + |[ + | [ + | [ + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | + | ... with length 1000 + | ] + | [ + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | [1.0] + | + | ... with length 1000 + | ] + | ] + |] + |""".stripMargin + require(nd.toString.split("\\s+").mkString == data.split("\\s+").mkString) + nd = NDArray.ones(Shape(1, 4)) + data = + """ + |[ + | [1.0,1.0,1.0,1.0] + |] + |""".stripMargin + require(nd.toString.split("\\s+").mkString == data.split("\\s+").mkString) + } + test("plus") { var ndzeros = NDArray.zeros(2, 1) var ndones = ndzeros + 1f