diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 7264c39e84a0..4de65c0b361a 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -138,12 +138,6 @@
INTERNAL
provided
-
- junit
- junit
- 4.11
- test
-
commons-io
commons-io
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
index 77881ab940be..0f756e24027f 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
@@ -37,7 +37,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
- * @return NDArray in HWC format
+ * @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte], flag: Int,
to_rgb: Boolean,
@@ -56,7 +56,7 @@ object Image {
/**
* Same imageDecode with InputStream
* @param inputStream the inputStream of the image
- * @return NDArray in HWC format
+ * @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream, flag: Int = 1,
to_rgb: Boolean = true,
@@ -78,7 +78,7 @@ object Image {
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param to_rgb Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
- * @return org.apache.mxnet.NDArray in HWC format
+ * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
def imRead(filename: String, flag: Option[Int] = None,
to_rgb: Option[Boolean] = None,
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 5c345f21faf4..4324b3dbe63e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -97,9 +97,11 @@ object NDArray extends NDArrayBase {
case ndArr: Seq[NDArray @unchecked] =>
if (ndArr.head.isInstanceOf[NDArray]) (ndArr.toArray, ndArr.toArray.map(_.handle))
else throw new IllegalArgumentException(
- "Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
+ s"""Unsupported out ${output.getClass} type,
+ | should be NDArray or subclass of Seq[NDArray]""".stripMargin)
case _ => throw new IllegalArgumentException(
- "Unsupported out var type, should be NDArray or subclass of Seq[NDArray]")
+ s"""Unsupported out ${output.getClass} type,
+ | should be NDArray or subclass of Seq[NDArray]""".stripMargin)
}
} else {
(null, null)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
new file mode 100644
index 000000000000..7d6f31e930ad
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+// scalastyle:off
+import java.awt.image.BufferedImage
+// scalastyle:on
+import java.io.InputStream
+
+object Image {
+ /**
+ * Decode image with OpenCV.
+ * Note: return image in RGB by default, instead of OpenCV's default BGR.
+ * @param buf Buffer containing binary encoded image
+ * @param flag Convert decoded image to grayscale (0) or color (1).
+ * @param toRGB Whether to convert decoded image
+ * to mxnet's default RGB format (instead of opencv's default BGR).
+ * @return NDArray in HWC format with DType [[DType.UInt8]]
+ */
+ def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
+ org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None)
+ }
+
+ def imDecode(buf: Array[Byte]): NDArray = {
+ imDecode(buf, 1, true)
+ }
+
+ /**
+ * Same imageDecode with InputStream
+ *
+ * @param inputStream the inputStream of the image
+ * @param flag Convert decoded image to grayscale (0) or color (1).
+ * @param toRGB Whether to convert decoded image
+ * @return NDArray in HWC format with DType [[DType.UInt8]]
+ */
+ def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = {
+ org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None)
+ }
+
+ def imDecode(inputStream: InputStream): NDArray = {
+ imDecode(inputStream, 1, true)
+ }
+
+ /**
+ * Read and decode image with OpenCV.
+ * Note: return image in RGB by default, instead of OpenCV's default BGR.
+ * @param filename Name of the image file to be loaded.
+ * @param flag Convert decoded image to grayscale (0) or color (1).
+ * @param toRGB Whether to convert decoded image to mxnet's default RGB format
+ * (instead of opencv's default BGR).
+ * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
+ */
+ def imRead(filename: String, flag: Int, toRGB: Boolean): NDArray = {
+ org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None)
+ }
+
+ def imRead(filename: String): NDArray = {
+ imRead(filename, 1, true)
+ }
+
+ /**
+ * Resize image with OpenCV.
+ * @param src source image in NDArray
+ * @param w Width of resized image.
+ * @param h Height of resized image.
+ * @param interp Interpolation method (default=cv2.INTER_LINEAR).
+ * @return org.apache.mxnet.NDArray
+ */
+ def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
+ val interpVal = if (interp == null) None else Some(interp.intValue())
+ org.apache.mxnet.Image.imResize(src, w, h, interpVal, None)
+ }
+
+ def imResize(src: NDArray, w: Int, h: Int): NDArray = {
+ imResize(src, w, h, null)
+ }
+
+ /**
+ * Do a fixed crop on the image
+ * @param src Src image in NDArray
+ * @param x0 starting x point
+ * @param y0 starting y point
+ * @param w width of the image
+ * @param h height of the image
+ * @return cropped NDArray
+ */
+ def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
+ org.apache.mxnet.Image.fixedCrop(src, x0, y0, w, h)
+ }
+
+ /**
+ * Convert a NDArray image to a real image
+ * The time cost will increase if the image resolution is big
+ * @param src Source image file in RGB
+ * @return Buffered Image
+ */
+ def toImage(src: NDArray): BufferedImage = {
+ org.apache.mxnet.Image.toImage(src)
+ }
+}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
new file mode 100644
index 000000000000..0092744a21a8
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.apache.commons.io.FileUtils;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import java.io.File;
+import java.net.URL;
+
+import static org.junit.Assert.assertArrayEquals;
+
+public class ImageTest {
+
+ private static String imLocation;
+
+ private static void downloadUrl(String url, String filePath, int maxRetry) throws Exception{
+ File tmpFile = new File(filePath);
+ Boolean success = false;
+ if (!tmpFile.exists()) {
+ while (maxRetry > 0 && !success) {
+ try {
+ FileUtils.copyURLToFile(new URL(url), tmpFile);
+ success = true;
+ } catch(Exception e){
+ maxRetry -= 1;
+ }
+ }
+ } else {
+ success = true;
+ }
+ if (!success) throw new Exception("$url Download failed!");
+ }
+
+ @BeforeClass
+ public static void downloadFile() throws Exception {
+ String tempDirPath = System.getProperty("java.io.tmpdir");
+ imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg";
+ downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
+ imLocation, 3);
+ }
+
+ @Test
+ public void testImageProcess() {
+ NDArray nd = Image.imRead(imLocation, 1, true);
+ assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3});
+ NDArray nd2 = Image.imResize(nd, 224, 224, null);
+ assertArrayEquals(nd2.shape().toArray(), new int[]{224, 224, 3});
+ NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
+ Image.toImage(cropped);
+ }
+}
diff --git a/scala-package/examples/pom.xml b/scala-package/examples/pom.xml
index 564102a9f696..30ccfdcce12e 100644
--- a/scala-package/examples/pom.xml
+++ b/scala-package/examples/pom.xml
@@ -15,6 +15,7 @@
true
+ ${skipTests}
diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java
index c9b4426f52b3..c5d209998d32 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExample.java
@@ -24,8 +24,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import javax.imageio.ImageIO;
-import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
@@ -47,76 +45,7 @@ public class PredictorExample {
private String inputImagePath = "/images/dog.jpg";
final static Logger logger = LoggerFactory.getLogger(PredictorExample.class);
-
- /**
- * Load the image from file to buffered image
- * It can be replaced by loadImageFromFile from ObjectDetector
- * @param inputImagePath input image Path in String
- * @return Buffered image
- */
- private static BufferedImage loadIamgeFromFile(String inputImagePath) {
- BufferedImage buf = null;
- try {
- buf = ImageIO.read(new File(inputImagePath));
- } catch (IOException e) {
- System.err.println(e);
- }
- return buf;
- }
-
- /**
- * Reshape the current image using ImageIO and Graph2D
- * It can be replaced by reshapeImage from ObjectDetector
- * @param buf Buffered image
- * @param newWidth desired width
- * @param newHeight desired height
- * @return a reshaped bufferedImage
- */
- private static BufferedImage reshapeImage(BufferedImage buf, int newWidth, int newHeight) {
- BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
- Graphics2D g = resizedImage.createGraphics();
- g.drawImage(buf, 0, 0, newWidth, newHeight, null);
- g.dispose();
- return resizedImage;
- }
-
- /**
- * Convert an image from a buffered image into pixels float array
- * It can be replaced by bufferedImageToPixels from ObjectDetector
- * @param buf buffered image
- * @return Float array
- */
- private static float[] imagePreprocess(BufferedImage buf) {
- // Get height and width of the image
- int w = buf.getWidth();
- int h = buf.getHeight();
-
- // get an array of integer pixels in the default RGB color mode
- int[] pixels = buf.getRGB(0, 0, w, h, null, 0, w);
-
- // 3 times height and width for R,G,B channels
- float[] result = new float[3 * h * w];
-
- int row = 0;
- // copy pixels to array vertically
- while (row < h) {
- int col = 0;
- // copy pixels to array horizontally
- while (col < w) {
- int rgb = pixels[row * w + col];
- // getting red color
- result[0 * h * w + row * w + col] = (rgb >> 16) & 0xFF;
- // getting green color
- result[1 * h * w + row * w + col] = (rgb >> 8) & 0xFF;
- // getting blue color
- result[2 * h * w + row * w + col] = rgb & 0xFF;
- col += 1;
- }
- row += 1;
- }
- buf.flush();
- return result;
- }
+ private static NDArray$ NDArray = NDArray$.MODULE$;
/**
* Helper class to print the maximum prediction result
@@ -170,11 +99,10 @@ public static void main(String[] args) {
inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0);
// Prepare data
- BufferedImage img = loadIamgeFromFile(inst.inputImagePath);
-
- img = reshapeImage(img, 224, 224);
+ NDArray img = Image.imRead(inst.inputImagePath, 1, true);
+ img = Image.imResize(img, 224, 224, null);
// predict
- float[][] result = predictor.predict(new float[][]{imagePreprocess(img)});
+ float[][] result = predictor.predict(new float[][]{img.toArray()});
try {
System.out.println("Predict with Float input");
System.out.println(printMaximumClass(result[0], inst.modelPathPrefix));
@@ -182,10 +110,10 @@ public static void main(String[] args) {
System.err.println(e);
}
// predict with NDArray
- NDArray nd = new NDArray(
- imagePreprocess(img),
- new Shape(new int[]{1, 3, 224, 224}),
- Context.cpu());
+ NDArray nd = img;
+ nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0];
+ nd = NDArray.expand_dims(nd, 0, null)[0];
+ nd = nd.asType(DType.Float32());
List ndList = new ArrayList<>();
ndList.add(nd);
List ndResult = predictor.predictWithNDArray(ndList);
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala
index c1ff10c6c8a2..dba343160bff 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala
@@ -24,9 +24,9 @@ import org.apache.commons.io.FileUtils
object Util {
- def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = {
+ def downloadUrl(url: String, filePath: String, maxRetry: Int = 3) : Unit = {
val tmpFile = new File(filePath)
- var retry = maxRetry.getOrElse(3)
+ var retry = maxRetry
var success = false
if (!tmpFile.exists()) {
while (retry > 0 && !success) {
diff --git a/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java
new file mode 100644
index 000000000000..30bc8db447d8
--- /dev/null
+++ b/scala-package/examples/src/test/java/org/apache/mxnetexamples/javaapi/infer/predictor/PredictorExampleTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.javaapi.infer.predictor;
+
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.apache.mxnetexamples.Util;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+
+public class PredictorExampleTest {
+
+ final static Logger logger = LoggerFactory.getLogger(PredictorExampleTest.class);
+ private static String modelPathPrefix = "";
+ private static String inputImagePath = "";
+
+ @BeforeClass
+ public static void downloadFile() {
+ logger.info("Downloading resnet-18 model");
+
+ String tempDirPath = System.getProperty("java.io.tmpdir");
+ logger.info("tempDirPath: %s".format(tempDirPath));
+
+ String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models";
+
+ Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json",
+ tempDirPath + "/resnet18/resnet-18-symbol.json", 3);
+ Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params",
+ tempDirPath + "/resnet18/resnet-18-0000.params", 3);
+ Util.downloadUrl(baseUrl + "/resnet-18/synset.txt",
+ tempDirPath + "/resnet18/synset.txt", 3);
+ Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
+ tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg", 3);
+
+ modelPathPrefix = tempDirPath + File.separator + "resnet18/resnet-18";
+ inputImagePath = tempDirPath + File.separator +
+ "inputImages/resnet18/Pug-Cookie.jpg";
+ }
+
+ @Test
+ public void testPredictor(){
+ PredictorExample example = new PredictorExample();
+ String[] args = new String[]{
+ "--model-path-prefix", modelPathPrefix,
+ "--input-image", inputImagePath
+ };
+ example.main(args);
+ }
+
+}
diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml
index 13ceebb83cd9..565ac6e393a5 100644
--- a/scala-package/infer/pom.xml
+++ b/scala-package/infer/pom.xml
@@ -64,13 +64,5 @@
1.10.19
test
-
-
- junit
- junit
- 4.11
- test
-
-
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
index 4dfd6eb044a1..fa3565b4fb8e 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -96,9 +96,9 @@ private[mxnet] object JavaNDArrayMacro extends GeneratorBase {
// add default out parameter
argDef += s"out: org.apache.mxnet.javaapi.NDArray"
if (useParamObject) {
- impl += "if (po.getOut() != null) map(\"out\") = po.getOut()"
+ impl += "if (po.getOut() != null) map(\"out\") = po.getOut().nd"
} else {
- impl += "if (out != null) map(\"out\") = out"
+ impl += "if (out != null) map(\"out\") = out.nd"
}
val returnType = "Array[org.apache.mxnet.javaapi.NDArray]"
// scalastyle:off
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 07baeab9b63f..5ba6f1f9f498 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -412,6 +412,12 @@
1.13.5
test
+
+ junit
+ junit
+ 4.11
+ test
+