diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index 67809c158aff..50139ec1be22 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -242,6 +242,8 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
this(NDArray.array(arr, shape, ctx))
}
+ override def toString: String = nd.toString
+
def serialize(): Array[Byte] = nd.serialize()
/**
diff --git a/scala-package/mxnet-demo/java-demo/README.md b/scala-package/mxnet-demo/java-demo/README.md
index 760219343ed2..cad52cbc44cb 100644
--- a/scala-package/mxnet-demo/java-demo/README.md
+++ b/scala-package/mxnet-demo/java-demo/README.md
@@ -16,9 +16,15 @@
# MXNet Java Sample Project
-This is an project created to use Maven-published Scala/Java package with two Java examples.
+This is a project demonstrating how to use the Maven published Scala/Java MXNet package.
+The examples provided include:
+* NDArray creation
+* NDArray operation
+* Object Detection using the Inference API
+* Image Classification using the Predictor API
+
## Setup
-You are required to use Maven to build the package with the following commands:
+You are required to use Maven to build the package with the following commands under `java-demo`:
```
mvn package
```
@@ -42,16 +48,16 @@ The `SCALA_PKG_PROFILE` should be chosen from `osx-x86_64-cpu`, `linux-x86_64-cp
## Run
-### Hello World
-The Scala file is being executed using Java. You can execute the helloWorld example as follows:
+### NDArrayCreation
+The Scala file is being executed using Java. You can execute the `NDArrayCreation` example as follows:
```Bash
bash bin/java_sample.sh
```
You can also run the following command manually:
```Bash
-java -cp $CLASSPATH sample.HelloWorld
+java -cp $CLASSPATH sample.NDArrayCreation
```
-However, you have to define the Classpath before you run the demo code. More information can be found in the `java_sample.sh`.
+However, you have to define the Classpath before you run the demo code. More information can be found in `bin/java_sample.sh`.
The `CLASSPATH` should point to the jar file you have downloaded.
It will load the library automatically and run the example
diff --git a/scala-package/mxnet-demo/java-demo/bin/java_sample.sh b/scala-package/mxnet-demo/java-demo/bin/java_sample.sh
index 4fb724aca8db..fb1795f20f9d 100755
--- a/scala-package/mxnet-demo/java-demo/bin/java_sample.sh
+++ b/scala-package/mxnet-demo/java-demo/bin/java_sample.sh
@@ -17,4 +17,4 @@
#!/bin/bash
CURR_DIR=$(cd $(dirname $0)/../; pwd)
CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/*
-java -Xmx8G -cp $CLASSPATH mxnet.HelloWorld
\ No newline at end of file
+java -Xmx8G -cp $CLASSPATH mxnet.NDArrayCreation
diff --git a/scala-package/mxnet-demo/java-demo/bin/run_od.sh b/scala-package/mxnet-demo/java-demo/bin/run_od.sh
index abd0bf5b1b93..4370518dc8cd 100755
--- a/scala-package/mxnet-demo/java-demo/bin/run_od.sh
+++ b/scala-package/mxnet-demo/java-demo/bin/run_od.sh
@@ -17,4 +17,4 @@
#!/bin/bash
CURR_DIR=$(cd $(dirname $0)/../; pwd)
CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/*
-java -Xmx8G -cp $CLASSPATH mxnet.ObjectDetection
\ No newline at end of file
+java -Xmx8G -cp $CLASSPATH mxnet.ObjectDetection
diff --git a/scala-package/mxnet-demo/java-demo/pom.xml b/scala-package/mxnet-demo/java-demo/pom.xml
index 39253b1ce918..eb5e043a0dda 100644
--- a/scala-package/mxnet-demo/java-demo/pom.xml
+++ b/scala-package/mxnet-demo/java-demo/pom.xml
@@ -82,6 +82,12 @@
mxnet-full_${mxnet.scalaprofile}-${mxnet.profile}
${mxnet.version}
+
+ org.apache.mxnet
+ mxnet-full_${mxnet.scalaprofile}-${mxnet.profile}
+ ${mxnet.version}
+ sources
+
commons-io
commons-io
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ImageClassification.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ImageClassification.java
new file mode 100644
index 000000000000..8cb58da5c2e6
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ImageClassification.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package mxnet;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.mxnet.infer.javaapi.Predictor;
+import org.apache.mxnet.javaapi.*;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.IOException;
+import java.net.URL;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ImageClassification {
+ private static String modelPath;
+ private static String imagePath;
+
+ private static void downloadUrl(String url, String filePath) {
+ File tmpFile = new File(filePath);
+ if (!tmpFile.exists()) {
+ try {
+ FileUtils.copyURLToFile(new URL(url), tmpFile);
+ } catch (Exception exception) {
+ System.err.println(exception);
+ }
+ }
+ }
+
+ public static void downloadModelImage() {
+ String tempDirPath = System.getProperty("java.io.tmpdir");
+ String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models";
+ downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json",
+ tempDirPath + "/resnet18/resnet-18-symbol.json");
+ downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params",
+ tempDirPath + "/resnet18/resnet-18-0000.params");
+ downloadUrl(baseUrl + "/resnet-18/synset.txt",
+ tempDirPath + "/resnet18/synset.txt");
+ downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
+ tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg");
+ modelPath = tempDirPath + File.separator + "resnet18/resnet-18";
+ imagePath = tempDirPath + File.separator +
+ "inputImages/resnet18/Pug-Cookie.jpg";
+ }
+
+ /**
+ * Helper class to print the maximum prediction result
+ * @param probabilities The float array of probability
+ * @param modelPathPrefix model Path needs to load the synset.txt
+ */
+ private static String printMaximumClass(float[] probabilities,
+ String modelPathPrefix) throws IOException {
+ String synsetFilePath = modelPathPrefix.substring(0,
+ 1 + modelPathPrefix.lastIndexOf(File.separator)) + "/synset.txt";
+ BufferedReader reader = new BufferedReader(new FileReader(synsetFilePath));
+ ArrayList list = new ArrayList<>();
+ String line = reader.readLine();
+
+ while (line != null){
+ list.add(line);
+ line = reader.readLine();
+ }
+ reader.close();
+
+ int maxIdx = 0;
+ for (int i = 1;i probabilities[maxIdx]) {
+ maxIdx = i;
+ }
+ }
+
+ return "Probability : " + probabilities[maxIdx] + " Class : " + list.get(maxIdx) ;
+ }
+
+ public static void main(String[] args) {
+ // Download the model and Image
+ downloadModelImage();
+
+ // Prepare the model
+ List context = new ArrayList();
+ context.add(Context.cpu());
+ List inputDesc = new ArrayList<>();
+ Shape inputShape = new Shape(new int[]{1, 3, 224, 224});
+ inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
+ Predictor predictor = new Predictor(modelPath, inputDesc, context,0);
+
+ // Prepare data
+ NDArray nd = Image.imRead(imagePath, 1, true);
+ nd = Image.imResize(nd, 224, 224, null);
+ nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0]; // HWC to CHW
+ nd = NDArray.expand_dims(nd, 0, null)[0]; // Add N -> NCHW
+ nd = nd.asType(DType.Float32()); // Inference with Float32
+
+ // Predict directly
+ float[][] result = predictor.predict(new float[][]{nd.toArray()});
+ try {
+ System.out.println("Predict with Float input");
+ System.out.println(printMaximumClass(result[0], modelPath));
+ } catch (IOException e) {
+ System.err.println(e);
+ }
+
+ // predict with NDArray
+ List ndList = new ArrayList<>();
+ ndList.add(nd);
+ List ndResult = predictor.predictWithNDArray(ndList);
+ try {
+ System.out.println("Predict with NDArray");
+ System.out.println(printMaximumClass(ndResult.get(0).toArray(), modelPath));
+ } catch (IOException e) {
+ System.err.println(e);
+ }
+ }
+}
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java
new file mode 100644
index 000000000000..32e2d84dcdbf
--- /dev/null
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package mxnet;
+
+import org.apache.mxnet.javaapi.*;
+
+public class NDArrayCreation {
+ static NDArray$ NDArray = NDArray$.MODULE$;
+ public static void main(String[] args) {
+
+ // Create new NDArray
+ NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu());
+ System.out.println(nd);
+
+ // create new Double NDArray
+ NDArray ndDouble = new NDArray(new double[]{2.0d, 3.0d}, new Shape(new int[]{2, 1}), Context.cpu());
+ System.out.println(ndDouble);
+
+ // create ones
+ NDArray ones = NDArray.ones(Context.cpu(), new int[] {1, 2, 3});
+ System.out.println(ones);
+
+ // random
+ NDArray random = NDArray.random_uniform(
+ NDArray.new random_uniformParam()
+ .setLow(0.0f)
+ .setHigh(2.0f)
+ .setShape(new Shape(new int[]{10, 10}))
+ )[0];
+ System.out.println(random);
+ }
+}
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java
similarity index 67%
rename from scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java
rename to scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java
index 71981e2691c5..56a414307f46 100644
--- a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java
@@ -14,19 +14,31 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package mxnet;
import org.apache.mxnet.javaapi.*;
-import java.util.Arrays;
-public class HelloWorld {
+public class NDArrayOperation {
static NDArray$ NDArray = NDArray$.MODULE$;
-
public static void main(String[] args) {
- System.out.println("Hello World!");
NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu());
- System.out.println(nd.shape());
- NDArray nd2 = NDArray.dot(NDArray.new dotParam(nd, nd.T()))[0];
- System.out.println(Arrays.toString(nd2.toArray()));
+
+ // Transpose
+ NDArray ndT = nd.T();
+ System.out.println(nd);
+ System.out.println(ndT);
+
+ // change Data Type
+ NDArray ndInt = nd.asType(DType.Int32());
+ System.out.println(ndInt);
+
+ // element add
+ NDArray eleAdd = NDArray.elemwise_add(nd, nd, null)[0];
+ System.out.println(eleAdd);
+
+ // norm (L2 Norm)
+ NDArray normed = NDArray.norm(NDArray.new normParam(nd))[0];
+ System.out.println(normed);
}
}
diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java
index cfe9b66c4b3f..65fe286aa2c7 100644
--- a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java
+++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java
@@ -68,20 +68,18 @@ public static void downloadModelImage() {
public static void main(String[] args) {
List context = new ArrayList();
- if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
- Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
- context.add(Context.gpu());
- } else {
- context.add(Context.cpu());
- }
+ context.add(Context.cpu());
downloadModelImage();
+
+ List> output
+ = runObjectDetectionSingle(modelPath, imagePath, context);
+
Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
Shape outputShape = new Shape(new int[] {1, 6132, 6});
int width = inputShape.get(2);
int height = inputShape.get(3);
- List> output
- = runObjectDetectionSingle(modelPath, imagePath, context);
String outputStr = "\n";
+
for (List ele : output) {
for (ObjectDetectorOutput i : ele) {
outputStr += "Class: " + i.getClassName() + "\n";
@@ -98,4 +96,4 @@ public static void main(String[] args) {
}
System.out.println(outputStr);
}
-}
\ No newline at end of file
+}