Skip to content

Commit

Permalink
Faster Scala NDArray to BufferedImage function (apache#16219)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk authored and larroy committed Sep 28, 2019
1 parent 1f62dd0 commit 531ae68
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,18 @@ object Image {
def toImage(src: NDArray): BufferedImage = {
require(src.dtype == DType.UInt8, "The input NDArray must be bytes")
require(src.shape.length == 3, "The input should contains height, width and channel")
require(src.shape(2) == 3, "There should be three channels: RGB")
val height = src.shape.get(0)
val width = src.shape.get(1)
val img = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB)
val arr = src.toArray
(0 until height).par.foreach(r => {
(0 until width).par.foreach(c => {
val arr = src.at(r).at(c).toArray
// NDArray in RGB
val red = arr(0).toByte & 0xFF
val green = arr(1).toByte & 0xFF
val blue = arr(2).toByte & 0xFF
val cellIndex = r * width * 3 + c * 3
val red = arr(cellIndex).toByte & 0xFF
val green = arr(cellIndex + 1).toByte & 0xFF
val blue = arr(cellIndex + 2).toByte & 0xFF
val rgb = (red << 16) | (green << 8) | blue
img.setRGB(c, r, rgb)
})
Expand Down

0 comments on commit 531ae68

Please sign in to comment.