Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add Visualize Util and migrate visualize structure to there
Browse files Browse the repository at this point in the history
  • Loading branch information
lanking520 committed Oct 24, 2018
1 parent ebc14c4 commit f59e098
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 56 deletions.
57 changes: 1 addition & 56 deletions scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,6 @@ object NDArray extends NDArrayBase {
"_onehot_encode", Seq(indices, out), Map("out" -> out))(0)
}

/**
* Get the String representation of NDArray
* @param nd input NDArray
* @return String
*/
def toString(nd : NDArray) : String = {
nd.visualize
}

/**
* Create an empty uninitialized new NDArray, with specified shape.
*
Expand Down Expand Up @@ -704,53 +695,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
}

override def toString() : String = {
s"${this.visualize}<NDArray ${this.shape} ${this.context}>"
}

/**
* Visualize the internal structure of NDArray
* @return String that show the structure
*/
def visualize: String = {
buildStringHelper(this, this.shape.length) + "\n"
}
/**
* Helper function to create formatted NDArray output
* The NDArray will be represented in a reduced version if too large
* @param nd NDArray as the input
* @param totalSpace totalSpace of the lowest dimension
* @return String format of NDArray
*/
private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = {
var result = ""
val THRESHOLD = 100000 // longest NDArray to show in full
val ARRAYTHRESHOLD = 1000 // longest array to show in full
val shape = nd.shape
val space = totalSpace - shape.length
if (shape.length != 1) {
val (length, postfix) =
if (shape.product > THRESHOLD) {
// reduced NDArray
(1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n")
} else {
(shape(0), "")
}
for (num <- 0 until length) {
val output = buildStringHelper(nd.at(num), totalSpace)
result += s"$output\n"
}
result = s"${" " * space}[\n$result${" " * space}$postfix]"
} else {
if (shape(0) > ARRAYTHRESHOLD) {
// reduced Array
val front = nd.slice(0, 10)
val back = nd.slice(shape(0) - 10, shape(0) - 1)
result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]"
} else {
result = s"${" " * space}[${nd.toArray.mkString(",")}]"
}
}
result
s"<NDArray ${this.shape} ${this.context}>"
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.util

import org.apache.mxnet.{NDArray, Shape}

/**
* A visualize helper class to see the internal structure
* of mxnet data-structure
*/
object Visualize {

/**
* Visualize the internal structure of NDArray
* @return String that show the structure
*/
def toString(nd : NDArray): String = {
buildStringHelper(nd, nd.shape.length) + "\n"
}
/**
* Helper function to create formatted NDArray output
* The NDArray will be represented in a reduced version if too large
* @param nd NDArray as the input
* @param totalSpace totalSpace of the lowest dimension
* @return String format of NDArray
*/
private def buildStringHelper(nd : NDArray, totalSpace : Int) : String = {
var result = ""
val THRESHOLD = 100000 // longest NDArray to show in full
val ARRAYTHRESHOLD = 1000 // longest array to show in full
val shape = nd.shape
val space = totalSpace - shape.length
if (shape.length != 1) {
val (length, postfix) =
if (shape.product > THRESHOLD) {
// reduced NDArray
(1, s"\n${" " * (space + 1)}... with length ${shape(0)}\n")
} else {
(shape(0), "")
}
for (num <- 0 until length) {
val output = buildStringHelper(nd.at(num), totalSpace)
result += s"$output\n"
}
result = s"${" " * space}[\n$result${" " * space}$postfix]"
} else {
if (shape(0) > ARRAYTHRESHOLD) {
// reduced Array
val front = nd.slice(0, 10)
val back = nd.slice(shape(0) - 10, shape(0) - 1)
result = s"${" " * space}[${front.toArray.mkString(",")} ... ${back.toArray.mkString(",")}]"
} else {
result = s"${" " * space}[${nd.toArray.mkString(",")}]"
}
}
result
}
}

0 comments on commit f59e098

Please sign in to comment.