Skip to content

Commit

Permalink
[MXNET-1180] Java Image API (apache#13807)
Browse files Browse the repository at this point in the history
* add java example

* add test and change PredictorExample

* add image change

* Add minor fixes

* add License

* add predictor Example tests

* fix the issue with JUnit test

* Satisfy Lint God ʕ •ᴥ•ʔ

* update the pom file config

* update documentation

* add simplified methods
  • Loading branch information
lanking520 authored and stephenrawls committed Feb 16, 2019
1 parent 87b6711 commit 9b8fbf4
Show file tree
Hide file tree
Showing 12 changed files with 274 additions and 103 deletions.
6 changes: 0 additions & 6 deletions scala-package/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,6 @@
<version>INTERNAL</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte], flag: Int,
to_rgb: Boolean,
Expand All @@ -56,7 +56,7 @@ object Image {
/**
* Same imageDecode with InputStream
* @param inputStream the inputStream of the image
* @return NDArray in HWC format
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream, flag: Int = 1,
to_rgb: Boolean = true,
Expand All @@ -78,7 +78,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
def imRead(filename: String, flag: Option[Int] = None,
to_rgb: Option[Boolean] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ object NDArray extends NDArrayBase {
case ndArr: Seq[NDArray @unchecked] =>
if (ndArr.head.isInstanceOf[NDArray]) (ndArr.toArray, ndArr.toArray.map(_.handle))
else throw new IllegalArgumentException(
"Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
s"""Unsupported out ${output.getClass} type,
| should be NDArray or subclass of Seq[NDArray]""".stripMargin)
case _ => throw new IllegalArgumentException(
"Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
s"""Unsupported out ${output.getClass} type,
| should be NDArray or subclass of Seq[NDArray]""".stripMargin)
}
} else {
(null, null)
Expand Down
114 changes: 114 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi
// scalastyle:off
import java.awt.image.BufferedImage
// scalastyle:on
import java.io.InputStream

object Image {
/**
* Decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param buf Buffer containing binary encoded image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None)
}

def imDecode(buf: Array[Byte]): NDArray = {
imDecode(buf, 1, true)
}

/**
* Same imageDecode with InputStream
*
* @param inputStream the inputStream of the image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = {
org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None)
}

def imDecode(inputStream: InputStream): NDArray = {
imDecode(inputStream, 1, true)
}

/**
* Read and decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param filename Name of the image file to be loaded.
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
def imRead(filename: String, flag: Int, toRGB: Boolean): NDArray = {
org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None)
}

def imRead(filename: String): NDArray = {
imRead(filename, 1, true)
}

/**
* Resize image with OpenCV.
* @param src source image in NDArray
* @param w Width of resized image.
* @param h Height of resized image.
* @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @return org.apache.mxnet.NDArray
*/
def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
val interpVal = if (interp == null) None else Some(interp.intValue())
org.apache.mxnet.Image.imResize(src, w, h, interpVal, None)
}

def imResize(src: NDArray, w: Int, h: Int): NDArray = {
imResize(src, w, h, null)
}

/**
* Do a fixed crop on the image
* @param src Src image in NDArray
* @param x0 starting x point
* @param y0 starting y point
* @param w width of the image
* @param h height of the image
* @return cropped NDArray
*/
def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
org.apache.mxnet.Image.fixedCrop(src, x0, y0, w, h)
}

/**
* Convert a NDArray image to a real image
* The time cost will increase if the image resolution is big
* @param src Source image file in RGB
* @return Buffered Image
*/
def toImage(src: NDArray): BufferedImage = {
org.apache.mxnet.Image.toImage(src)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mxnet.javaapi;

import org.apache.commons.io.FileUtils;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
import java.net.URL;

import static org.junit.Assert.assertArrayEquals;

public class ImageTest {

private static String imLocation;

private static void downloadUrl(String url, String filePath, int maxRetry) throws Exception{
File tmpFile = new File(filePath);
Boolean success = false;
if (!tmpFile.exists()) {
while (maxRetry > 0 && !success) {
try {
FileUtils.copyURLToFile(new URL(url), tmpFile);
success = true;
} catch(Exception e){
maxRetry -= 1;
}
}
} else {
success = true;
}
if (!success) throw new Exception("$url Download failed!");
}

@BeforeClass
public static void downloadFile() throws Exception {
String tempDirPath = System.getProperty("java.io.tmpdir");
imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg";
downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
imLocation, 3);
}

@Test
public void testImageProcess() {
NDArray nd = Image.imRead(imLocation, 1, true);
assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3});
NDArray nd2 = Image.imResize(nd, 224, 224, null);
assertArrayEquals(nd2.shape().toArray(), new int[]{224, 224, 3});
NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
Image.toImage(cropped);
}
}
1 change: 1 addition & 0 deletions scala-package/examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

<properties>
<skipTests>true</skipTests>
<skipJavaTests>${skipTests}</skipJavaTests>
</properties>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.imageio.ImageIO;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
Expand All @@ -47,76 +45,7 @@ public class PredictorExample {
private String inputImagePath = "/images/dog.jpg";

final static Logger logger = LoggerFactory.getLogger(PredictorExample.class);

/**
* Load the image from file to buffered image
* It can be replaced by loadImageFromFile from ObjectDetector
* @param inputImagePath input image Path in String
* @return Buffered image
*/
private static BufferedImage loadIamgeFromFile(String inputImagePath) {
BufferedImage buf = null;
try {
buf = ImageIO.read(new File(inputImagePath));
} catch (IOException e) {
System.err.println(e);
}
return buf;
}

/**
* Reshape the current image using ImageIO and Graph2D
* It can be replaced by reshapeImage from ObjectDetector
* @param buf Buffered image
* @param newWidth desired width
* @param newHeight desired height
* @return a reshaped bufferedImage
*/
private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) {
BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
Graphics2D g = resizedImage.createGraphics();
g.drawImage(buf, 0, 0, newWidth, newHeight, null);
g.dispose();
return resizedImage;
}

/**
* Convert an image from a buffered image into pixels float array
* It can be replaced by bufferedImageToPixels from ObjectDetector
* @param buf buffered image
* @return Float array
*/
private static float[] imagePreprocess(BufferedImage buf) {
// Get height and width of the image
int w = buf.getWidth();
int h = buf.getHeight();

// get an array of integer pixels in the default RGB color mode
int[] pixels = buf.getRGB(0, 0, w, h, null, 0, w);

// 3 times height and width for R,G,B channels
float[] result = new float[3 * h * w];

int row = 0;
// copy pixels to array vertically
while (row < h) {
int col = 0;
// copy pixels to array horizontally
while (col < w) {
int rgb = pixels[row * w + col];
// getting red color
result[0 * h * w + row * w + col] = (rgb >> 16) & 0xFF;
// getting green color
result[1 * h * w + row * w + col] = (rgb >> 8) & 0xFF;
// getting blue color
result[2 * h * w + row * w + col] = rgb & 0xFF;
col += 1;
}
row += 1;
}
buf.flush();
return result;
}
private static NDArray$ NDArray = NDArray$.MODULE$;

/**
* Helper class to print the maximum prediction result
Expand Down Expand Up @@ -170,22 +99,21 @@ public static void main(String[] args) {
inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0);
// Prepare data
BufferedImage img = loadIamgeFromFile(inst.inputImagePath);

img = reshapeImage(img, 224, 224);
NDArray img = Image.imRead(inst.inputImagePath, 1, true);
img = Image.imResize(img, 224, 224, null);
// predict
float[][] result = predictor.predict(new float[][]{imagePreprocess(img)});
float[][] result = predictor.predict(new float[][]{img.toArray()});
try {
System.out.println("Predict with Float input");
System.out.println(printMaximumClass(result[0], inst.modelPathPrefix));
} catch (IOException e) {
System.err.println(e);
}
// predict with NDArray
NDArray nd = new NDArray(
imagePreprocess(img),
new Shape(new int[]{1, 3, 224, 224}),
Context.cpu());
NDArray nd = img;
nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0];
nd = NDArray.expand_dims(nd, 0, null)[0];
nd = nd.asType(DType.Float32());
List<NDArray> ndList = new ArrayList<>();
ndList.add(nd);
List<NDArray> ndResult = predictor.predictWithNDArray(ndList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import org.apache.commons.io.FileUtils

object Util {

def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = {
def downloadUrl(url: String, filePath: String, maxRetry: Int = 3) : Unit = {
val tmpFile = new File(filePath)
var retry = maxRetry.getOrElse(3)
var retry = maxRetry
var success = false
if (!tmpFile.exists()) {
while (retry > 0 && !success) {
Expand Down
Loading

0 comments on commit 9b8fbf4

Please sign in to comment.