diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala index 67809c158aff..50139ec1be22 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala @@ -242,6 +242,8 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) { this(NDArray.array(arr, shape, ctx)) } + override def toString: String = nd.toString + def serialize(): Array[Byte] = nd.serialize() /** diff --git a/scala-package/mxnet-demo/java-demo/README.md b/scala-package/mxnet-demo/java-demo/README.md index 760219343ed2..cad52cbc44cb 100644 --- a/scala-package/mxnet-demo/java-demo/README.md +++ b/scala-package/mxnet-demo/java-demo/README.md @@ -16,9 +16,15 @@ # MXNet Java Sample Project -This is an project created to use Maven-published Scala/Java package with two Java examples. +This is a project demonstrating how to use the Maven published Scala/Java MXNet package. +The examples provided include: +* NDArray creation +* NDArray operation +* Object Detection using the Inference API +* Image Classification using the Predictor API + ## Setup -You are required to use Maven to build the package with the following commands: +You are required to use Maven to build the package with the following commands under `java-demo`: ``` mvn package ``` @@ -42,16 +48,16 @@ The `SCALA_PKG_PROFILE` should be chosen from `osx-x86_64-cpu`, `linux-x86_64-cp ## Run -### Hello World -The Scala file is being executed using Java. You can execute the helloWorld example as follows: +### NDArrayCreation +The Scala file is being executed using Java. You can execute the `NDArrayCreation` example as follows: ```Bash bash bin/java_sample.sh ``` You can also run the following command manually: ```Bash -java -cp $CLASSPATH sample.HelloWorld +java -cp $CLASSPATH sample.NDArrayCreation ``` -However, you have to define the Classpath before you run the demo code. More information can be found in the `java_sample.sh`. +However, you have to define the Classpath before you run the demo code. More information can be found in `bin/java_sample.sh`. The `CLASSPATH` should point to the jar file you have downloaded. It will load the library automatically and run the example diff --git a/scala-package/mxnet-demo/java-demo/bin/java_sample.sh b/scala-package/mxnet-demo/java-demo/bin/java_sample.sh index 4fb724aca8db..fb1795f20f9d 100755 --- a/scala-package/mxnet-demo/java-demo/bin/java_sample.sh +++ b/scala-package/mxnet-demo/java-demo/bin/java_sample.sh @@ -17,4 +17,4 @@ #!/bin/bash CURR_DIR=$(cd $(dirname $0)/../; pwd) CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/* -java -Xmx8G -cp $CLASSPATH mxnet.HelloWorld \ No newline at end of file +java -Xmx8G -cp $CLASSPATH mxnet.NDArrayCreation diff --git a/scala-package/mxnet-demo/java-demo/bin/run_od.sh b/scala-package/mxnet-demo/java-demo/bin/run_od.sh index abd0bf5b1b93..4370518dc8cd 100755 --- a/scala-package/mxnet-demo/java-demo/bin/run_od.sh +++ b/scala-package/mxnet-demo/java-demo/bin/run_od.sh @@ -17,4 +17,4 @@ #!/bin/bash CURR_DIR=$(cd $(dirname $0)/../; pwd) CLASSPATH=$CLASSPATH:$CURR_DIR/target/*:$CLASSPATH:$CURR_DIR/target/dependency/* -java -Xmx8G -cp $CLASSPATH mxnet.ObjectDetection \ No newline at end of file +java -Xmx8G -cp $CLASSPATH mxnet.ObjectDetection diff --git a/scala-package/mxnet-demo/java-demo/pom.xml b/scala-package/mxnet-demo/java-demo/pom.xml index 39253b1ce918..eb5e043a0dda 100644 --- a/scala-package/mxnet-demo/java-demo/pom.xml +++ b/scala-package/mxnet-demo/java-demo/pom.xml @@ -82,6 +82,12 @@ mxnet-full_${mxnet.scalaprofile}-${mxnet.profile} ${mxnet.version} + + org.apache.mxnet + mxnet-full_${mxnet.scalaprofile}-${mxnet.profile} + ${mxnet.version} + sources + commons-io commons-io diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ImageClassification.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ImageClassification.java new file mode 100644 index 000000000000..8cb58da5c2e6 --- /dev/null +++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ImageClassification.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mxnet; + +import org.apache.commons.io.FileUtils; +import org.apache.mxnet.infer.javaapi.Predictor; +import org.apache.mxnet.javaapi.*; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.net.URL; +import java.util.ArrayList; +import java.util.List; + +public class ImageClassification { + private static String modelPath; + private static String imagePath; + + private static void downloadUrl(String url, String filePath) { + File tmpFile = new File(filePath); + if (!tmpFile.exists()) { + try { + FileUtils.copyURLToFile(new URL(url), tmpFile); + } catch (Exception exception) { + System.err.println(exception); + } + } + } + + public static void downloadModelImage() { + String tempDirPath = System.getProperty("java.io.tmpdir"); + String baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models"; + downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json", + tempDirPath + "/resnet18/resnet-18-symbol.json"); + downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params", + tempDirPath + "/resnet18/resnet-18-0000.params"); + downloadUrl(baseUrl + "/resnet-18/synset.txt", + tempDirPath + "/resnet18/synset.txt"); + downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg", + tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg"); + modelPath = tempDirPath + File.separator + "resnet18/resnet-18"; + imagePath = tempDirPath + File.separator + + "inputImages/resnet18/Pug-Cookie.jpg"; + } + + /** + * Helper class to print the maximum prediction result + * @param probabilities The float array of probability + * @param modelPathPrefix model Path needs to load the synset.txt + */ + private static String printMaximumClass(float[] probabilities, + String modelPathPrefix) throws IOException { + String synsetFilePath = modelPathPrefix.substring(0, + 1 + modelPathPrefix.lastIndexOf(File.separator)) + "/synset.txt"; + BufferedReader reader = new BufferedReader(new FileReader(synsetFilePath)); + ArrayList list = new ArrayList<>(); + String line = reader.readLine(); + + while (line != null){ + list.add(line); + line = reader.readLine(); + } + reader.close(); + + int maxIdx = 0; + for (int i = 1;i probabilities[maxIdx]) { + maxIdx = i; + } + } + + return "Probability : " + probabilities[maxIdx] + " Class : " + list.get(maxIdx) ; + } + + public static void main(String[] args) { + // Download the model and Image + downloadModelImage(); + + // Prepare the model + List context = new ArrayList(); + context.add(Context.cpu()); + List inputDesc = new ArrayList<>(); + Shape inputShape = new Shape(new int[]{1, 3, 224, 224}); + inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW")); + Predictor predictor = new Predictor(modelPath, inputDesc, context,0); + + // Prepare data + NDArray nd = Image.imRead(imagePath, 1, true); + nd = Image.imResize(nd, 224, 224, null); + nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0]; // HWC to CHW + nd = NDArray.expand_dims(nd, 0, null)[0]; // Add N -> NCHW + nd = nd.asType(DType.Float32()); // Inference with Float32 + + // Predict directly + float[][] result = predictor.predict(new float[][]{nd.toArray()}); + try { + System.out.println("Predict with Float input"); + System.out.println(printMaximumClass(result[0], modelPath)); + } catch (IOException e) { + System.err.println(e); + } + + // predict with NDArray + List ndList = new ArrayList<>(); + ndList.add(nd); + List ndResult = predictor.predictWithNDArray(ndList); + try { + System.out.println("Predict with NDArray"); + System.out.println(printMaximumClass(ndResult.get(0).toArray(), modelPath)); + } catch (IOException e) { + System.err.println(e); + } + } +} diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java new file mode 100644 index 000000000000..32e2d84dcdbf --- /dev/null +++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayCreation.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package mxnet; + +import org.apache.mxnet.javaapi.*; + +public class NDArrayCreation { + static NDArray$ NDArray = NDArray$.MODULE$; + public static void main(String[] args) { + + // Create new NDArray + NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu()); + System.out.println(nd); + + // create new Double NDArray + NDArray ndDouble = new NDArray(new double[]{2.0d, 3.0d}, new Shape(new int[]{2, 1}), Context.cpu()); + System.out.println(ndDouble); + + // create ones + NDArray ones = NDArray.ones(Context.cpu(), new int[] {1, 2, 3}); + System.out.println(ones); + + // random + NDArray random = NDArray.random_uniform( + NDArray.new random_uniformParam() + .setLow(0.0f) + .setHigh(2.0f) + .setShape(new Shape(new int[]{10, 10})) + )[0]; + System.out.println(random); + } +} diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java similarity index 67% rename from scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java rename to scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java index 71981e2691c5..56a414307f46 100644 --- a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/HelloWorld.java +++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/NDArrayOperation.java @@ -14,19 +14,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package mxnet; import org.apache.mxnet.javaapi.*; -import java.util.Arrays; -public class HelloWorld { +public class NDArrayOperation { static NDArray$ NDArray = NDArray$.MODULE$; - public static void main(String[] args) { - System.out.println("Hello World!"); NDArray nd = new NDArray(new float[]{2.0f, 3.0f}, new Shape(new int[]{1, 2}), Context.cpu()); - System.out.println(nd.shape()); - NDArray nd2 = NDArray.dot(NDArray.new dotParam(nd, nd.T()))[0]; - System.out.println(Arrays.toString(nd2.toArray())); + + // Transpose + NDArray ndT = nd.T(); + System.out.println(nd); + System.out.println(ndT); + + // change Data Type + NDArray ndInt = nd.asType(DType.Int32()); + System.out.println(ndInt); + + // element add + NDArray eleAdd = NDArray.elemwise_add(nd, nd, null)[0]; + System.out.println(eleAdd); + + // norm (L2 Norm) + NDArray normed = NDArray.norm(NDArray.new normParam(nd))[0]; + System.out.println(normed); } } diff --git a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java index cfe9b66c4b3f..65fe286aa2c7 100644 --- a/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java +++ b/scala-package/mxnet-demo/java-demo/src/main/java/mxnet/ObjectDetection.java @@ -68,20 +68,18 @@ public static void downloadModelImage() { public static void main(String[] args) { List context = new ArrayList(); - if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && - Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { - context.add(Context.gpu()); - } else { - context.add(Context.cpu()); - } + context.add(Context.cpu()); downloadModelImage(); + + List> output + = runObjectDetectionSingle(modelPath, imagePath, context); + Shape inputShape = new Shape(new int[] {1, 3, 512, 512}); Shape outputShape = new Shape(new int[] {1, 6132, 6}); int width = inputShape.get(2); int height = inputShape.get(3); - List> output - = runObjectDetectionSingle(modelPath, imagePath, context); String outputStr = "\n"; + for (List ele : output) { for (ObjectDetectorOutput i : ele) { outputStr += "Class: " + i.getClassName() + "\n"; @@ -98,4 +96,4 @@ public static void main(String[] args) { } System.out.println(outputStr); } -} \ No newline at end of file +}