From f59e098b3fd6cd4ca5b417e47f2c86d36792cf51 Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 23 Oct 2018 18:10:11 -0700 Subject: [PATCH] add Visualize Util and migrate visualize structure to there --- .../main/scala/org/apache/mxnet/NDArray.scala | 57 +-------------- .../org/apache/mxnet/util/Visualize.scala | 73 +++++++++++++++++++ 2 files changed, 74 insertions(+), 56 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 40774db772fc..74e899361594 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -197,15 +197,6 @@ object NDArray extends NDArrayBase { "_onehot_encode", Seq(indices, out), Map("out" -> out))(0) } - /** - * Get the String representation of NDArray - * @param nd input NDArray - * @return String - */ - def toString(nd : NDArray) : String = { - nd.visualize - } - /** * Create an empty uninitialized new NDArray, with specified shape. * @@ -704,53 +695,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, } override def toString() : String = { - s"${this.visualize}" - } - - /** - * Visualize the internal structure of NDArray - * @return String that show the structure - */ - def visualize: String = { - buildStringHelper(this, this.shape.length) + "\n" - } - /** - * Helper function to create formatted NDArray output - * The NDArray will be represented in a reduced version if too large - * @param nd NDArray as the input - * @param totalSpace totalSpace of the lowest dimension - * @return String format of NDArray - */ - private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { - var result = "" - val THRESHOLD = 100000 // longest NDArray to show in full - val ARRAYTHRESHOLD = 1000 // longest array to show in full - val shape = nd.shape - val space = totalSpace - shape.length - if (shape.length != 1) { - val (length, postfix) = - if (shape.product > THRESHOLD) { - // reduced NDArray - (1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") - } else { - (shape(0), "") - } - for (num <- 0 until length) { - val output = buildStringHelper(nd.at(num), totalSpace) - result += s"$output\n" - } - result = s"${" " * space}[\n$result${" " * space}$postfix]" - } else { - if (shape(0) > ARRAYTHRESHOLD) { - // reduced Array - val front = nd.slice(0, 10) - val back = nd.slice(shape(0) - 10, shape(0) - 1) - result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]" - } else { - result = s"${" " * space}[${nd.toArray.mkString(",")}]" - } - } - result + s"" } /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala b/scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala new file mode 100644 index 000000000000..0cafc12a7f8b --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/util/Visualize.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.util + +import org.apache.mxnet.{NDArray, Shape} + +/** + * A visualize helper class to see the internal structure + * of mxnet data-structure + */ +object Visualize { + + /** + * Visualize the internal structure of NDArray + * @return String that show the structure + */ + def toString(nd : NDArray): String = { + buildStringHelper(nd, nd.shape.length) + "\n" + } + /** + * Helper function to create formatted NDArray output + * The NDArray will be represented in a reduced version if too large + * @param nd NDArray as the input + * @param totalSpace totalSpace of the lowest dimension + * @return String format of NDArray + */ + private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = { + var result = "" + val THRESHOLD = 100000 // longest NDArray to show in full + val ARRAYTHRESHOLD = 1000 // longest array to show in full + val shape = nd.shape + val space = totalSpace - shape.length + if (shape.length != 1) { + val (length, postfix) = + if (shape.product > THRESHOLD) { + // reduced NDArray + (1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n") + } else { + (shape(0), "") + } + for (num <- 0 until length) { + val output = buildStringHelper(nd.at(num), totalSpace) + result += s"$output\n" + } + result = s"${" " * space}[\n$result${" " * space}$postfix]" + } else { + if (shape(0) > ARRAYTHRESHOLD) { + // reduced Array + val front = nd.slice(0, 10) + val back = nd.slice(shape(0) - 10, shape(0) - 1) + result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]" + } else { + result = s"${" " * space}[${nd.toArray.mkString(",")}]" + } + } + result + } +}