From 706a2222cd44341742e22f05cbcd47e227736f14 Mon Sep 17 00:00:00 2001 From: Qing Date: Sat, 5 Jan 2019 11:51:02 -0800 Subject: [PATCH 01/11] add java example --- .../org/apache/mxnet/javaapi/Image.scala | 96 +++++++++++++++++++ .../infer/predictor/PredictorExample.java | 77 +-------------- 2 files changed, 98 insertions(+), 75 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala new file mode 100644 index 000000000000..9276452c3ba0 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet.javaapi +// scalastyle:off +import java.awt.image.BufferedImage +// scalastyle:on +import java.io.InputStream + +object Image { + /** + * Decode image with OpenCV. + * Note: return image in RGB by default, instead of OpenCV's default BGR. + * @param buf Buffer containing binary encoded image + * @param flag Convert decoded image to grayscale (0) or color (1). + * @param toRGB Whether to convert decoded image + * to mxnet's default RGB format (instead of opencv's default BGR). + * @return NDArray in HWC format + */ + def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean, out: NDArray): NDArray = { + org.apache.mxnet.Image.imDecode(buf, flag, toRGB, Some(out)) + } + + /** + * Same imageDecode with InputStream + * @param inputStream the inputStream of the image + * @return NDArray in HWC format + */ + def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true, + out: NDArray): NDArray = { + org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, Some(out)) + } + + /** + * Read and decode image with OpenCV. + * Note: return image in RGB by default, instead of OpenCV's default BGR. + * @param filename Name of the image file to be loaded. + * @param flag Convert decoded image to grayscale (0) or color (1). + * @param toRGB Whether to convert decoded image to mxnet's default RGB format + * (instead of opencv's default BGR). + * @return org.apache.mxnet.NDArray in HWC format + */ + def imRead(filename: String, flag: Int, toRGB: Boolean = true, out: NDArray): NDArray = { + org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), Some(out)) + } + + /** + * Resize image with OpenCV. + * @param src source image in NDArray + * @param w Width of resized image. + * @param h Height of resized image. + * @param interp Interpolation method (default=cv2.INTER_LINEAR). + * @return org.apache.mxnet.NDArray + */ + def imResize(src: NDArray, w: Int, h: Int, + interp: Integer, out: NDArray): NDArray = { + org.apache.mxnet.Image.imResize(src, w, h, Some(interp), Some(out)) + } + + /** + * Do a fixed crop on the image + * @param src Src image in NDArray + * @param x0 starting x point + * @param y0 starting y point + * @param w width of the image + * @param h height of the image + * @return cropped NDArray + */ + def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = { + org.apache.mxnet.Image.fixedCrop(src, x0, y0, w, h) + } + + /** + * Convert a NDArray image to a real image + * The time cost will increase if the image resolution is big + * @param src Source image file in RGB + * @return Buffered Image + */ + def toImage(src: NDArray): BufferedImage = { + org.apache.mxnet.Image.toImage(src) + } +} diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java index c9b4426f52b3..d89835c8dcaa 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java @@ -24,8 +24,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.imageio.ImageIO; -import java.awt.Graphics2D; import java.awt.image.BufferedImage; import java.io.BufferedReader; import java.io.File; @@ -48,76 +46,6 @@ public class PredictorExample { final static Logger logger = LoggerFactory.getLogger(PredictorExample.class); - /** - * Load the image from file to buffered image - * It can be replaced by loadImageFromFile from ObjectDetector - * @param inputImagePath input image Path in String - * @return Buffered image - */ - private static BufferedImage loadIamgeFromFile(String inputImagePath) { - BufferedImage buf = null; - try { - buf = ImageIO.read(new File(inputImagePath)); - } catch (IOException e) { - System.err.println(e); - } - return buf; - } - - /** - * Reshape the current image using ImageIO and Graph2D - * It can be replaced by reshapeImage from ObjectDetector - * @param buf Buffered image - * @param newWidth desired width - * @param newHeight desired height - * @return a reshaped bufferedImage - */ - private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) { - BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB); - Graphics2D g = resizedImage.createGraphics(); - g.drawImage(buf, 0, 0, newWidth, newHeight, null); - g.dispose(); - return resizedImage; - } - - /** - * Convert an image from a buffered image into pixels float array - * It can be replaced by bufferedImageToPixels from ObjectDetector - * @param buf buffered image - * @return Float array - */ - private static float[] imagePreprocess(BufferedImage buf) { - // Get height and width of the image - int w = buf.getWidth(); - int h = buf.getHeight(); - - // get an array of integer pixels in the default RGB color mode - int[] pixels = buf.getRGB(0, 0, w, h, null, 0, w); - - // 3 times height and width for R,G,B channels - float[] result = new float[3 * h * w]; - - int row = 0; - // copy pixels to array vertically - while (row < h) { - int col = 0; - // copy pixels to array horizontally - while (col < w) { - int rgb = pixels[row * w + col]; - // getting red color - result[0 * h * w + row * w + col] = (rgb >> 16) & 0xFF; - // getting green color - result[1 * h * w + row * w + col] = (rgb >> 8) & 0xFF; - // getting blue color - result[2 * h * w + row * w + col] = rgb & 0xFF; - col += 1; - } - row += 1; - } - buf.flush(); - return result; - } - /** * Helper class to print the maximum prediction result * @param probabilities The float array of probability @@ -170,9 +98,8 @@ public static void main(String[] args) { inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW")); Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0); // Prepare data - BufferedImage img = loadIamgeFromFile(inst.inputImagePath); - - img = reshapeImage(img, 224, 224); + NDArray img = Image.imRead(inst.inputImagePath, 1, true, null); + img = Image.imResize(img, 224, 224, null, null); // predict float[][] result = predictor.predict(new float[][]{imagePreprocess(img)}); try { From c7f7e0875ee739f260802e161cdd88e6731c5dfe Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 7 Jan 2019 16:39:42 -0800 Subject: [PATCH 02/11] add test and change PredictorExample --- .../org/apache/mxnet/javaapi/ImageTest.java | 50 +++++++++++++++++++ .../infer/predictor/PredictorExample.java | 7 +-- 2 files changed, 52 insertions(+), 5 deletions(-) create mode 100644 scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java new file mode 100644 index 000000000000..7a028be30527 --- /dev/null +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java @@ -0,0 +1,50 @@ +package org.apache.mxnet.javaapi; + +import org.apache.commons.io.FileUtils; +import org.junit.BeforeClass; +import org.junit.Test; +import java.io.File; +import java.net.URL; + +public class ImageTest { + + private String imLocation; + + private void downloadUrl(String url, String filePath, int maxRetry) throws Exception{ + File tmpFile = new File(filePath); + Boolean success = false; + if (!tmpFile.exists()) { + while (maxRetry > 0 && !success) { + try { + FileUtils.copyURLToFile(new URL(url), tmpFile); + success = true; + } catch(Exception e){ + maxRetry -= 1; + } + } + } else { + success = true; + } + if (!success) throw new Exception("$url Download failed!"); + } + + @BeforeClass + public void downloadFile() throws Exception { + String tempDirPath = System.getProperty("java.io.tmpdir"); + imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg"; + try { + downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg", + imLocation, 3); + } catch (Exception e) { + throw e; + } + } + + @Test + public void testImageProcess() { + NDArray nd = Image.imRead(imLocation, 1, true, null); + NDArray nd2 = Image.imResize(nd, 224, 224, null, null); + NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224); + Image.toImage(cropped); + } +} diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java index d89835c8dcaa..a1c3401c2378 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java @@ -101,7 +101,7 @@ public static void main(String[] args) { NDArray img = Image.imRead(inst.inputImagePath, 1, true, null); img = Image.imResize(img, 224, 224, null, null); // predict - float[][] result = predictor.predict(new float[][]{imagePreprocess(img)}); + float[][] result = predictor.predict(new float[][]{img.toArray()}); try { System.out.println("Predict with Float input"); System.out.println(printMaximumClass(result[0], inst.modelPathPrefix)); @@ -109,10 +109,7 @@ public static void main(String[] args) { System.err.println(e); } // predict with NDArray - NDArray nd = new NDArray( - imagePreprocess(img), - new Shape(new int[]{1, 3, 224, 224}), - Context.cpu()); + NDArray nd = img; List ndList = new ArrayList<>(); ndList.add(nd); List ndResult = predictor.predictWithNDArray(ndList); From 195c5d9ee2cecca7898608f37a8bc1e07796f28f Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 7 Jan 2019 17:22:39 -0800 Subject: [PATCH 03/11] add image change --- .../org/apache/mxnet/javaapi/Image.scala | 19 +++++++++---------- .../org/apache/mxnet/javaapi/ImageTest.java | 14 +++++++++----- .../infer/predictor/PredictorExample.java | 4 ++-- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala index 9276452c3ba0..9d5216cbc862 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala @@ -31,8 +31,8 @@ object Image { * to mxnet's default RGB format (instead of opencv's default BGR). * @return NDArray in HWC format */ - def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean, out: NDArray): NDArray = { - org.apache.mxnet.Image.imDecode(buf, flag, toRGB, Some(out)) + def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = { + org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None) } /** @@ -40,9 +40,8 @@ object Image { * @param inputStream the inputStream of the image * @return NDArray in HWC format */ - def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true, - out: NDArray): NDArray = { - org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, Some(out)) + def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true): NDArray = { + org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None) } /** @@ -54,8 +53,8 @@ object Image { * (instead of opencv's default BGR). * @return org.apache.mxnet.NDArray in HWC format */ - def imRead(filename: String, flag: Int, toRGB: Boolean = true, out: NDArray): NDArray = { - org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), Some(out)) + def imRead(filename: String, flag: Int, toRGB: Boolean = true): NDArray = { + org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None) } /** @@ -66,9 +65,9 @@ object Image { * @param interp Interpolation method (default=cv2.INTER_LINEAR). * @return org.apache.mxnet.NDArray */ - def imResize(src: NDArray, w: Int, h: Int, - interp: Integer, out: NDArray): NDArray = { - org.apache.mxnet.Image.imResize(src, w, h, Some(interp), Some(out)) + def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = { + val interpVal = if (interp == null) None else Some(interp.intValue()) + org.apache.mxnet.Image.imResize(src, w, h, interpVal, None) } /** diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java index 7a028be30527..049b1d6156be 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java @@ -6,11 +6,13 @@ import java.io.File; import java.net.URL; +import static org.junit.Assert.assertArrayEquals; + public class ImageTest { - private String imLocation; + private static String imLocation; - private void downloadUrl(String url, String filePath, int maxRetry) throws Exception{ + private static void downloadUrl(String url, String filePath, int maxRetry) throws Exception{ File tmpFile = new File(filePath); Boolean success = false; if (!tmpFile.exists()) { @@ -29,7 +31,7 @@ private void downloadUrl(String url, String filePath, int maxRetry) throws Excep } @BeforeClass - public void downloadFile() throws Exception { + public static void downloadFile() throws Exception { String tempDirPath = System.getProperty("java.io.tmpdir"); imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg"; try { @@ -42,8 +44,10 @@ public void downloadFile() throws Exception { @Test public void testImageProcess() { - NDArray nd = Image.imRead(imLocation, 1, true, null); - NDArray nd2 = Image.imResize(nd, 224, 224, null, null); + NDArray nd = Image.imRead(imLocation, 1, true); + assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3}); + NDArray nd2 = Image.imResize(nd, 224, 224, null); + assertArrayEquals(nd.shape().toArray(), new int[]{224, 224, 3}); NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224); Image.toImage(cropped); } diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java index a1c3401c2378..4559315866ba 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java @@ -98,8 +98,8 @@ public static void main(String[] args) { inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW")); Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0); // Prepare data - NDArray img = Image.imRead(inst.inputImagePath, 1, true, null); - img = Image.imResize(img, 224, 224, null, null); + NDArray img = Image.imRead(inst.inputImagePath, 1, true); + img = Image.imResize(img, 224, 224, null); // predict float[][] result = predictor.predict(new float[][]{img.toArray()}); try { From 9ef19b5056f321aa10fd197c50e02970801b4be3 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 9 Jan 2019 10:40:22 -0800 Subject: [PATCH 04/11] Add minor fixes --- .../core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java index 049b1d6156be..da4911f9a865 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java @@ -47,7 +47,7 @@ public void testImageProcess() { NDArray nd = Image.imRead(imLocation, 1, true); assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3}); NDArray nd2 = Image.imResize(nd, 224, 224, null); - assertArrayEquals(nd.shape().toArray(), new int[]{224, 224, 3}); + assertArrayEquals(nd2.shape().toArray(), new int[]{224, 224, 3}); NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224); Image.toImage(cropped); } From 016d37424275a3c99775a1f316a759932920230c Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 9 Jan 2019 11:46:04 -0800 Subject: [PATCH 05/11] add License --- .../org/apache/mxnet/javaapi/ImageTest.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java index da4911f9a865..cc284c7c6910 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.mxnet.javaapi; import org.apache.commons.io.FileUtils; From f2f7483f725595d75a24e2aeba43dd90e0c97588 Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 11 Jan 2019 11:23:11 -0800 Subject: [PATCH 06/11] add predictor Example tests --- scala-package/examples/pom.xml | 7 +++ .../scala/org/apache/mxnetexamples/Util.scala | 4 +- .../predictor/PredictorExampleSuite.java | 50 +++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleSuite.java diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml index 564102a9f696..a41f6cb0aa10 100644 --- a/scala-package/examples/pom.xml +++ b/scala-package/examples/pom.xml @@ -15,6 +15,7 @@ true + ${skipTests} @@ -128,5 +129,11 @@ slf4j-simple 1.7.5 + + junit + junit + 4.11 + test + diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala index c1ff10c6c8a2..dba343160bff 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala @@ -24,9 +24,9 @@ import org.apache.commons.io.FileUtils object Util { - def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = { + def downloadUrl(url: String, filePath: String, maxRetry: Int = 3) : Unit = { val tmpFile = new File(filePath) - var retry = maxRetry.getOrElse(3) + var retry = maxRetry var success = false if (!tmpFile.exists()) { while (retry > 0 && !success) { diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleSuite.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleSuite.java new file mode 100644 index 000000000000..e4cc1c73103a --- /dev/null +++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleSuite.java @@ -0,0 +1,50 @@ +package org.apache.mxnetexamples.javaapi.infer.predictor; + +import org.junit.BeforeClass; +import org.junit.Test; +import org.apache.mxnetexamples.Util; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; + +public class PredictorExampleSuite { + + final static Logger logger = LoggerFactory.getLogger(PredictorExampleSuite.class); + private static String modelPathPrefix = ""; + private static String inputImagePath = ""; + + @BeforeClass + public static void downloadFile() { + logger.info("Downloading resnet-18 model"); + + String tempDirPath = System.getProperty("java.io.tmpdir"); + logger.info("tempDirPath: %s".format(tempDirPath)); + + String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models"; + + Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json", + tempDirPath + "/resnet18/resnet-18-symbol.json", 3); + Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params", + tempDirPath + "/resnet18/resnet-18-0000.params", 3); + Util.downloadUrl(baseUrl + "/resnet-18/synset.txt", + tempDirPath + "/resnet18/synset.txt", 3); + Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg", + tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg", 3); + + modelPathPrefix = tempDirPath + File.separator + "resnet18/resnet-18"; + inputImagePath = tempDirPath + File.separator + + "inputImages/resnet18/Pug-Cookie.jpg"; + } + + @Test + public void testPredictor(){ + PredictorExample example = new PredictorExample(); + String[] args = new String[]{ + "--model-path-prefix", modelPathPrefix, + "--input-image", inputImagePath + }; + example.main(args); + } + +} From 79b39e7656d36346f47c6582308af038d837ffc2 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 14 Jan 2019 16:01:23 -0800 Subject: [PATCH 07/11] fix the issue with JUnit test --- .../core/src/main/scala/org/apache/mxnet/Image.scala | 6 +++--- .../core/src/main/scala/org/apache/mxnet/NDArray.scala | 6 ++++-- .../src/main/scala/org/apache/mxnet/javaapi/Image.scala | 6 +++--- .../javaapi/infer/predictor/PredictorExample.java | 4 ++++ ...PredictorExampleSuite.java => PredictorExampleTest.java} | 5 +++-- .../scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala | 4 ++-- 6 files changed, 19 insertions(+), 12 deletions(-) rename scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/{PredictorExampleSuite.java => PredictorExampleTest.java} (95%) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala index 77881ab940be..473cb7c1ae91 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala @@ -37,7 +37,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param to_rgb Whether to convert decoded image * to mxnet's default RGB format (instead of opencv's default BGR). - * @return NDArray in HWC format + * @return NDArray in HWC format with DType uint8 */ def imDecode(buf: Array[Byte], flag: Int, to_rgb: Boolean, @@ -56,7 +56,7 @@ object Image { /** * Same imageDecode with InputStream * @param inputStream the inputStream of the image - * @return NDArray in HWC format + * @return NDArray in HWC format with DType uint8 */ def imDecode(inputStream: InputStream, flag: Int = 1, to_rgb: Boolean = true, @@ -78,7 +78,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param to_rgb Whether to convert decoded image to mxnet's default RGB format * (instead of opencv's default BGR). - * @return org.apache.mxnet.NDArray in HWC format + * @return org.apache.mxnet.NDArray in HWC format with DType uint8 */ def imRead(filename: String, flag: Option[Int] = None, to_rgb: Option[Boolean] = None, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 5c345f21faf4..4324b3dbe63e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -97,9 +97,11 @@ object NDArray extends NDArrayBase { case ndArr: Seq[NDArray @unchecked] => if (ndArr.head.isInstanceOf[NDArray]) (ndArr.toArray, ndArr.toArray.map(_.handle)) else throw new IllegalArgumentException( - "Unsupported out var type, should be NDArray or subclass of Seq[NDArray]") + s"""Unsupported out ${output.getClass} type, + | should be NDArray or subclass of Seq[NDArray]""".stripMargin) case _ => throw new IllegalArgumentException( - "Unsupported out var type, should be NDArray or subclass of Seq[NDArray]") + s"""Unsupported out ${output.getClass} type, + | should be NDArray or subclass of Seq[NDArray]""".stripMargin) } } else { (null, null) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala index 9d5216cbc862..531d0b1f96fa 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala @@ -29,7 +29,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param toRGB Whether to convert decoded image * to mxnet's default RGB format (instead of opencv's default BGR). - * @return NDArray in HWC format + * @return NDArray in HWC format with DType uint8 */ def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = { org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None) @@ -38,7 +38,7 @@ object Image { /** * Same imageDecode with InputStream * @param inputStream the inputStream of the image - * @return NDArray in HWC format + * @return NDArray in HWC format with DType uint8 */ def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true): NDArray = { org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None) @@ -51,7 +51,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param toRGB Whether to convert decoded image to mxnet's default RGB format * (instead of opencv's default BGR). - * @return org.apache.mxnet.NDArray in HWC format + * @return org.apache.mxnet.NDArray in HWC format with DType uint8 */ def imRead(filename: String, flag: Int, toRGB: Boolean = true): NDArray = { org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None) diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java index 4559315866ba..c5d209998d32 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java @@ -45,6 +45,7 @@ public class PredictorExample { private String inputImagePath = "/images/dog.jpg"; final static Logger logger = LoggerFactory.getLogger(PredictorExample.class); + private static NDArray$ NDArray = NDArray$.MODULE$; /** * Helper class to print the maximum prediction result @@ -110,6 +111,9 @@ public static void main(String[] args) { } // predict with NDArray NDArray nd = img; + nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0]; + nd = NDArray.expand_dims(nd, 0, null)[0]; + nd = nd.asType(DType.Float32()); List ndList = new ArrayList<>(); ndList.add(nd); List ndResult = predictor.predictWithNDArray(ndList); diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleSuite.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java similarity index 95% rename from scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleSuite.java rename to scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java index e4cc1c73103a..49656274807d 100644 --- a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleSuite.java +++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java @@ -1,5 +1,6 @@ package org.apache.mxnetexamples.javaapi.infer.predictor; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.apache.mxnetexamples.Util; @@ -8,9 +9,9 @@ import java.io.File; -public class PredictorExampleSuite { +public class PredictorExampleTest { - final static Logger logger = LoggerFactory.getLogger(PredictorExampleSuite.class); + final static Logger logger = LoggerFactory.getLogger(PredictorExampleTest.class); private static String modelPathPrefix = ""; private static String inputImagePath = ""; diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala index 4dfd6eb044a1..fa3565b4fb8e 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala @@ -96,9 +96,9 @@ private[mxnet] object JavaNDArrayMacro extends GeneratorBase { // add default out parameter argDef += s"out: org.apache.mxnet.javaapi.NDArray" if (useParamObject) { - impl += "if (po.getOut() != null) map(\"out\") = po.getOut()" + impl += "if (po.getOut() != null) map(\"out\") = po.getOut().nd" } else { - impl += "if (out != null) map(\"out\") = out" + impl += "if (out != null) map(\"out\") = out.nd" } val returnType = "Array[org.apache.mxnet.javaapi.NDArray]" // scalastyle:off From 4c87ea877734f8d5a2aa1ba3171713dc7f3ddbe7 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 14 Jan 2019 16:30:38 -0800 Subject: [PATCH 08/11] =?UTF-8?q?Satisfy=20Lint=20God=20=CA=95=20=E2=80=A2?= =?UTF-8?q?=E1=B4=A5=E2=80=A2=CA=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../infer/predictor/PredictorExampleTest.java | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java index 49656274807d..30bc8db447d8 100644 --- a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java +++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java @@ -1,6 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.mxnetexamples.javaapi.infer.predictor; -import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.apache.mxnetexamples.Util; From 48b0826c9959f80cfd6c297718e4969d10c0eb3a Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 16 Jan 2019 11:52:40 -0800 Subject: [PATCH 09/11] update the pom file config --- scala-package/core/pom.xml | 6 ------ .../src/test/java/org/apache/mxnet/javaapi/ImageTest.java | 8 ++------ scala-package/examples/pom.xml | 6 ------ scala-package/infer/pom.xml | 8 -------- scala-package/pom.xml | 6 ++++++ 5 files changed, 8 insertions(+), 26 deletions(-) diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index 7264c39e84a0..4de65c0b361a 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -138,12 +138,6 @@ INTERNAL provided - - junit - junit - 4.11 - test - commons-io commons-io diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java index cc284c7c6910..0092744a21a8 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java @@ -51,12 +51,8 @@ private static void downloadUrl(String url, String filePath, int maxRetry) throw public static void downloadFile() throws Exception { String tempDirPath = System.getProperty("java.io.tmpdir"); imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg"; - try { - downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg", - imLocation, 3); - } catch (Exception e) { - throw e; - } + downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg", + imLocation, 3); } @Test diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml index a41f6cb0aa10..30ccfdcce12e 100644 --- a/scala-package/examples/pom.xml +++ b/scala-package/examples/pom.xml @@ -129,11 +129,5 @@ slf4j-simple 1.7.5 - - junit - junit - 4.11 - test - diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml index 13ceebb83cd9..565ac6e393a5 100644 --- a/scala-package/infer/pom.xml +++ b/scala-package/infer/pom.xml @@ -64,13 +64,5 @@ 1.10.19 test - - - junit - junit - 4.11 - test - - diff --git a/scala-package/pom.xml b/scala-package/pom.xml index 07baeab9b63f..5ba6f1f9f498 100644 --- a/scala-package/pom.xml +++ b/scala-package/pom.xml @@ -412,6 +412,12 @@ 1.13.5 test + + junit + junit + 4.11 + test + From f1dcf99d0083e6e8904548651d43d7d76d299a9c Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 16 Jan 2019 11:58:51 -0800 Subject: [PATCH 10/11] update documentation --- .../core/src/main/scala/org/apache/mxnet/Image.scala | 6 +++--- .../src/main/scala/org/apache/mxnet/javaapi/Image.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala index 473cb7c1ae91..0f756e24027f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala @@ -37,7 +37,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param to_rgb Whether to convert decoded image * to mxnet's default RGB format (instead of opencv's default BGR). - * @return NDArray in HWC format with DType uint8 + * @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(buf: Array[Byte], flag: Int, to_rgb: Boolean, @@ -56,7 +56,7 @@ object Image { /** * Same imageDecode with InputStream * @param inputStream the inputStream of the image - * @return NDArray in HWC format with DType uint8 + * @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(inputStream: InputStream, flag: Int = 1, to_rgb: Boolean = true, @@ -78,7 +78,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param to_rgb Whether to convert decoded image to mxnet's default RGB format * (instead of opencv's default BGR). - * @return org.apache.mxnet.NDArray in HWC format with DType uint8 + * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]] */ def imRead(filename: String, flag: Option[Int] = None, to_rgb: Option[Boolean] = None, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala index 531d0b1f96fa..c469b01b7ea5 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala @@ -29,7 +29,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param toRGB Whether to convert decoded image * to mxnet's default RGB format (instead of opencv's default BGR). - * @return NDArray in HWC format with DType uint8 + * @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = { org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None) @@ -38,7 +38,7 @@ object Image { /** * Same imageDecode with InputStream * @param inputStream the inputStream of the image - * @return NDArray in HWC format with DType uint8 + * @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true): NDArray = { org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None) @@ -51,7 +51,7 @@ object Image { * @param flag Convert decoded image to grayscale (0) or color (1). * @param toRGB Whether to convert decoded image to mxnet's default RGB format * (instead of opencv's default BGR). - * @return org.apache.mxnet.NDArray in HWC format with DType uint8 + * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]] */ def imRead(filename: String, flag: Int, toRGB: Boolean = true): NDArray = { org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None) From 0f6e376bc5c1c0db7f4547cd759fd0e396790e15 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 30 Jan 2019 13:25:24 -0800 Subject: [PATCH 11/11] add simplified methods --- .../org/apache/mxnet/javaapi/Image.scala | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala index c469b01b7ea5..7d6f31e930ad 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala @@ -35,15 +35,26 @@ object Image { org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None) } + def imDecode(buf: Array[Byte]): NDArray = { + imDecode(buf, 1, true) + } + /** * Same imageDecode with InputStream + * * @param inputStream the inputStream of the image + * @param flag Convert decoded image to grayscale (0) or color (1). + * @param toRGB Whether to convert decoded image * @return NDArray in HWC format with DType [[DType.UInt8]] */ - def imDecode(inputStream: InputStream, flag: Int = 1, toRGB: Boolean = true): NDArray = { + def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = { org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None) } + def imDecode(inputStream: InputStream): NDArray = { + imDecode(inputStream, 1, true) + } + /** * Read and decode image with OpenCV. * Note: return image in RGB by default, instead of OpenCV's default BGR. @@ -53,10 +64,14 @@ object Image { * (instead of opencv's default BGR). * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]] */ - def imRead(filename: String, flag: Int, toRGB: Boolean = true): NDArray = { + def imRead(filename: String, flag: Int, toRGB: Boolean): NDArray = { org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None) } + def imRead(filename: String): NDArray = { + imRead(filename, 1, true) + } + /** * Resize image with OpenCV. * @param src source image in NDArray @@ -70,6 +85,10 @@ object Image { org.apache.mxnet.Image.imResize(src, w, h, interpVal, None) } + def imResize(src: NDArray, w: Int, h: Int): NDArray = { + imResize(src, w, h, null) + } + /** * Do a fixed crop on the image * @param src Src image in NDArray