From 9a3e4a02ded9c6d2c304557140dac0c9991d507e Mon Sep 17 00:00:00 2001 From: Lanking Date: Thu, 31 Jan 2019 10:55:12 -0800 Subject: [PATCH] [MXNET-1180] Java Image API (#13807) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add java example * add test and change PredictorExample * add image change * Add minor fixes * add License * add predictor Example tests * fix the issue with JUnit test * Satisfy Lint God ʕ •ᴥ•ʔ * update the pom file config * update documentation * add simplified methods --- scala-package/core/pom.xml | 6 - .../main/scala/org/apache/mxnet/Image.scala | 6 +- .../main/scala/org/apache/mxnet/NDArray.scala | 6 +- .../org/apache/mxnet/javaapi/Image.scala | 114 ++++++++++++++++++ .../org/apache/mxnet/javaapi/ImageTest.java | 67 ++++++++++ scala-package/examples/pom.xml | 1 + .../infer/predictor/PredictorExample.java | 88 ++------------ .../scala/org/apache/mxnetexamples/Util.scala | 4 +- .../infer/predictor/PredictorExampleTest.java | 67 ++++++++++ scala-package/infer/pom.xml | 8 -- .../mxnet/javaapi/JavaNDArrayMacro.scala | 4 +- scala-package/pom.xml | 6 + 12 files changed, 274 insertions(+), 103 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala create mode 100644 scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java create mode 100644 scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index 7264c39e84a0..4de65c0b361a 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -138,12 +138,6 @@ INTERNAL provided - - junit - junit - 4.11 - test - commons-io commons-io diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala index 77881ab940be..0f756e24027f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala @@ -37,7 +37,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param to_rgb Whether to convert decoded image * to mxnet's default RGB format (instead of opencv's default BGR). - * @return NDArray in HWC format + * @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(buf: Array[Byte], flag: Int, to_rgb: Boolean, @@ -56,7 +56,7 @@ object Image { /** * Same imageDecode with InputStream * @param inputStream the inputStream of the image - * @return NDArray in HWC format + * @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(inputStream: InputStream, flag: Int = 1, to_rgb: Boolean = true, @@ -78,7 +78,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param to_rgb Whether to convert decoded image to mxnet's default RGB format * (instead of opencv's default BGR). - * @return org.apache.mxnet.NDArray in HWC format + * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]] */ def imRead(filename: String, flag: Option[Int] = None, to_rgb: Option[Boolean] = None, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 5c345f21faf4..4324b3dbe63e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -97,9 +97,11 @@ object NDArray extends NDArrayBase { case ndArr: Seq[NDArray @unchecked] => if (ndArr.head.isInstanceOf[NDArray]) (ndArr.toArray, ndArr.toArray.map(_.handle)) else throw new IllegalArgumentException( - "Unsupported out var type, should be NDArray or subclass of Seq[NDArray]") + s"""Unsupported out ${output.getClass} type, + | should be NDArray or subclass of Seq[NDArray]""".stripMargin) case _ => throw new IllegalArgumentException( - "Unsupported out var type, should be NDArray or subclass of Seq[NDArray]") + s"""Unsupported out ${output.getClass} type, + | should be NDArray or subclass of Seq[NDArray]""".stripMargin) } } else { (null, null) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala new file mode 100644 index 000000000000..7d6f31e930ad --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi +// scalastyle:off +import java.awt.image.BufferedImage +// scalastyle:on +import java.io.InputStream + +object Image { + /** + * Decode image with OpenCV. + * Note: return image in RGB by default, instead of OpenCV's default BGR. + * @param buf Buffer containing binary encoded image + * @param flag Convert decoded image to grayscale (0) or color (1). + * @param toRGB Whether to convert decoded image + * to mxnet's default RGB format (instead of opencv's default BGR). + * @return NDArray in HWC format with DType [[DType.UInt8]] + */ + def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = { + org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None) + } + + def imDecode(buf: Array[Byte]): NDArray = { + imDecode(buf, 1, true) + } + + /** + * Same imageDecode with InputStream + * + * @param inputStream the inputStream of the image + * @param flag Convert decoded image to grayscale (0) or color (1). + * @param toRGB Whether to convert decoded image + * @return NDArray in HWC format with DType [[DType.UInt8]] + */ + def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = { + org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None) + } + + def imDecode(inputStream: InputStream): NDArray = { + imDecode(inputStream, 1, true) + } + + /** + * Read and decode image with OpenCV. + * Note: return image in RGB by default, instead of OpenCV's default BGR. + * @param filename Name of the image file to be loaded. + * @param flag Convert decoded image to grayscale (0) or color (1). + * @param toRGB Whether to convert decoded image to mxnet's default RGB format + * (instead of opencv's default BGR). + * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]] + */ + def imRead(filename: String, flag: Int, toRGB: Boolean): NDArray = { + org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None) + } + + def imRead(filename: String): NDArray = { + imRead(filename, 1, true) + } + + /** + * Resize image with OpenCV. + * @param src source image in NDArray + * @param w Width of resized image. + * @param h Height of resized image. + * @param interp Interpolation method (default=cv2.INTER_LINEAR). + * @return org.apache.mxnet.NDArray + */ + def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = { + val interpVal = if (interp == null) None else Some(interp.intValue()) + org.apache.mxnet.Image.imResize(src, w, h, interpVal, None) + } + + def imResize(src: NDArray, w: Int, h: Int): NDArray = { + imResize(src, w, h, null) + } + + /** + * Do a fixed crop on the image + * @param src Src image in NDArray + * @param x0 starting x point + * @param y0 starting y point + * @param w width of the image + * @param h height of the image + * @return cropped NDArray + */ + def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = { + org.apache.mxnet.Image.fixedCrop(src, x0, y0, w, h) + } + + /** + * Convert a NDArray image to a real image + * The time cost will increase if the image resolution is big + * @param src Source image file in RGB + * @return Buffered Image + */ + def toImage(src: NDArray): BufferedImage = { + org.apache.mxnet.Image.toImage(src) + } +} diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java new file mode 100644 index 000000000000..0092744a21a8 --- /dev/null +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi; + +import org.apache.commons.io.FileUtils; +import org.junit.BeforeClass; +import org.junit.Test; +import java.io.File; +import java.net.URL; + +import static org.junit.Assert.assertArrayEquals; + +public class ImageTest { + + private static String imLocation; + + private static void downloadUrl(String url, String filePath, int maxRetry) throws Exception{ + File tmpFile = new File(filePath); + Boolean success = false; + if (!tmpFile.exists()) { + while (maxRetry > 0 && !success) { + try { + FileUtils.copyURLToFile(new URL(url), tmpFile); + success = true; + } catch(Exception e){ + maxRetry -= 1; + } + } + } else { + success = true; + } + if (!success) throw new Exception("$url Download failed!"); + } + + @BeforeClass + public static void downloadFile() throws Exception { + String tempDirPath = System.getProperty("java.io.tmpdir"); + imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg"; + downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg", + imLocation, 3); + } + + @Test + public void testImageProcess() { + NDArray nd = Image.imRead(imLocation, 1, true); + assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3}); + NDArray nd2 = Image.imResize(nd, 224, 224, null); + assertArrayEquals(nd2.shape().toArray(), new int[]{224, 224, 3}); + NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224); + Image.toImage(cropped); + } +} diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml index 564102a9f696..30ccfdcce12e 100644 --- a/scala-package/examples/pom.xml +++ b/scala-package/examples/pom.xml @@ -15,6 +15,7 @@ true + ${skipTests} diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java index c9b4426f52b3..c5d209998d32 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java @@ -24,8 +24,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.imageio.ImageIO; -import java.awt.Graphics2D; import java.awt.image.BufferedImage; import java.io.BufferedReader; import java.io.File; @@ -47,76 +45,7 @@ public class PredictorExample { private String inputImagePath = "/images/dog.jpg"; final static Logger logger = LoggerFactory.getLogger(PredictorExample.class); - - /** - * Load the image from file to buffered image - * It can be replaced by loadImageFromFile from ObjectDetector - * @param inputImagePath input image Path in String - * @return Buffered image - */ - private static BufferedImage loadIamgeFromFile(String inputImagePath) { - BufferedImage buf = null; - try { - buf = ImageIO.read(new File(inputImagePath)); - } catch (IOException e) { - System.err.println(e); - } - return buf; - } - - /** - * Reshape the current image using ImageIO and Graph2D - * It can be replaced by reshapeImage from ObjectDetector - * @param buf Buffered image - * @param newWidth desired width - * @param newHeight desired height - * @return a reshaped bufferedImage - */ - private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) { - BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB); - Graphics2D g = resizedImage.createGraphics(); - g.drawImage(buf, 0, 0, newWidth, newHeight, null); - g.dispose(); - return resizedImage; - } - - /** - * Convert an image from a buffered image into pixels float array - * It can be replaced by bufferedImageToPixels from ObjectDetector - * @param buf buffered image - * @return Float array - */ - private static float[] imagePreprocess(BufferedImage buf) { - // Get height and width of the image - int w = buf.getWidth(); - int h = buf.getHeight(); - - // get an array of integer pixels in the default RGB color mode - int[] pixels = buf.getRGB(0, 0, w, h, null, 0, w); - - // 3 times height and width for R,G,B channels - float[] result = new float[3 * h * w]; - - int row = 0; - // copy pixels to array vertically - while (row < h) { - int col = 0; - // copy pixels to array horizontally - while (col < w) { - int rgb = pixels[row * w + col]; - // getting red color - result[0 * h * w + row * w + col] = (rgb >> 16) & 0xFF; - // getting green color - result[1 * h * w + row * w + col] = (rgb >> 8) & 0xFF; - // getting blue color - result[2 * h * w + row * w + col] = rgb & 0xFF; - col += 1; - } - row += 1; - } - buf.flush(); - return result; - } + private static NDArray$ NDArray = NDArray$.MODULE$; /** * Helper class to print the maximum prediction result @@ -170,11 +99,10 @@ public static void main(String[] args) { inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW")); Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0); // Prepare data - BufferedImage img = loadIamgeFromFile(inst.inputImagePath); - - img = reshapeImage(img, 224, 224); + NDArray img = Image.imRead(inst.inputImagePath, 1, true); + img = Image.imResize(img, 224, 224, null); // predict - float[][] result = predictor.predict(new float[][]{imagePreprocess(img)}); + float[][] result = predictor.predict(new float[][]{img.toArray()}); try { System.out.println("Predict with Float input"); System.out.println(printMaximumClass(result[0], inst.modelPathPrefix)); @@ -182,10 +110,10 @@ public static void main(String[] args) { System.err.println(e); } // predict with NDArray - NDArray nd = new NDArray( - imagePreprocess(img), - new Shape(new int[]{1, 3, 224, 224}), - Context.cpu()); + NDArray nd = img; + nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0]; + nd = NDArray.expand_dims(nd, 0, null)[0]; + nd = nd.asType(DType.Float32()); List ndList = new ArrayList<>(); ndList.add(nd); List ndResult = predictor.predictWithNDArray(ndList); diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala index c1ff10c6c8a2..dba343160bff 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala @@ -24,9 +24,9 @@ import org.apache.commons.io.FileUtils object Util { - def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = { + def downloadUrl(url: String, filePath: String, maxRetry: Int = 3) : Unit = { val tmpFile = new File(filePath) - var retry = maxRetry.getOrElse(3) + var retry = maxRetry var success = false if (!tmpFile.exists()) { while (retry > 0 && !success) { diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java new file mode 100644 index 000000000000..30bc8db447d8 --- /dev/null +++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnetexamples.javaapi.infer.predictor; + +import org.junit.BeforeClass; +import org.junit.Test; +import org.apache.mxnetexamples.Util; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; + +public class PredictorExampleTest { + + final static Logger logger = LoggerFactory.getLogger(PredictorExampleTest.class); + private static String modelPathPrefix = ""; + private static String inputImagePath = ""; + + @BeforeClass + public static void downloadFile() { + logger.info("Downloading resnet-18 model"); + + String tempDirPath = System.getProperty("java.io.tmpdir"); + logger.info("tempDirPath: %s".format(tempDirPath)); + + String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models"; + + Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json", + tempDirPath + "/resnet18/resnet-18-symbol.json", 3); + Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params", + tempDirPath + "/resnet18/resnet-18-0000.params", 3); + Util.downloadUrl(baseUrl + "/resnet-18/synset.txt", + tempDirPath + "/resnet18/synset.txt", 3); + Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg", + tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg", 3); + + modelPathPrefix = tempDirPath + File.separator + "resnet18/resnet-18"; + inputImagePath = tempDirPath + File.separator + + "inputImages/resnet18/Pug-Cookie.jpg"; + } + + @Test + public void testPredictor(){ + PredictorExample example = new PredictorExample(); + String[] args = new String[]{ + "--model-path-prefix", modelPathPrefix, + "--input-image", inputImagePath + }; + example.main(args); + } + +} diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml index 13ceebb83cd9..565ac6e393a5 100644 --- a/scala-package/infer/pom.xml +++ b/scala-package/infer/pom.xml @@ -64,13 +64,5 @@ 1.10.19 test - - - junit - junit - 4.11 - test - - diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala index 4dfd6eb044a1..fa3565b4fb8e 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -96,9 +96,9 @@ private[mxnet] object JavaNDArrayMacro extends GeneratorBase { // add default out parameter argDef += s"out: org.apache.mxnet.javaapi.NDArray" if (useParamObject) { - impl += "if (po.getOut() != null) map(\"out\") = po.getOut()" + impl += "if (po.getOut() != null) map(\"out\") = po.getOut().nd" } else { - impl += "if (out != null) map(\"out\") = out" + impl += "if (out != null) map(\"out\") = out.nd" } val returnType = "Array[org.apache.mxnet.javaapi.NDArray]" // scalastyle:off diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 07baeab9b63f..5ba6f1f9f498 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -412,6 +412,12 @@ 1.13.5 test + + junit + junit + 4.11 + test +