diff --git a/scala-package/.gitignore b/scala-package/.gitignore index 9bf7851716d6..dadc000c612e 100644 --- a/scala-package/.gitignore +++ b/scala-package/.gitignore @@ -9,3 +9,4 @@ core/src/main/scala/org/apache/mxnet/SymbolBase.scala core/src/main/scala/org/apache/mxnet/SymbolRandomAPIBase.scala examples/scripts/infer/images/ examples/scripts/infer/models/ +examples/scripts/infer/objectdetector/boundingImage.png diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala index 0f756e24027f..52e26efb41f1 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala @@ -17,6 +17,7 @@ package org.apache.mxnet // scalastyle:off +import java.awt.{BasicStroke, Color, Graphics2D} import java.awt.image.BufferedImage // scalastyle:on import java.io.InputStream @@ -182,4 +183,57 @@ object Image { img } + /** + * Helper function to generate random colors + * @param transparency The transparency level; None falls back to 1.0f + * @return Color + */ + private def randomColor(transparency: Option[Float] = Some(1.0f)) : Color = { + new Color( + Math.random().toFloat, Math.random().toFloat, Math.random().toFloat, + transparency.getOrElse(1.0f) + ) + } + + /** + * Method to draw bounding boxes for an image + * @param src Source of the buffered image + * @param coordinate Contains Map of xmin, xmax, ymin, ymax + * corresponding to top-left and down-right points + * @param names The name set of the bounding box + * @param stroke Thickness of the bounding box; None falls back to 3 + * @param fontSizeMult Font size multiplier; None falls back to 1.0f + * @param transparency Transparency of the bounding box; None falls back to 1.0f + */ + def drawBoundingBox(src: BufferedImage, coordinate: Array[Map[String, Int]], + names: Option[Array[String]] = None, + stroke : Option[Int] = Some(3), + fontSizeMult : Option[Float] = Some(1.0f), + transparency: Option[Float] = Some(1.0f)): Unit = { + val g2d : Graphics2D = src.createGraphics() + 
g2d.setStroke(new BasicStroke(stroke.getOrElse(3).toFloat)) + // Increase the size of font + val currentFont = g2d.getFont + val newFont = currentFont.deriveFont(currentFont.getSize * fontSizeMult.getOrElse(1.0f)) + g2d.setFont(newFont) + // Get font metrics to draw the font box + val fm = g2d.getFontMetrics(newFont) + for (idx <- coordinate.indices) { + val map = coordinate(idx) + g2d.setColor(randomColor(transparency).darker()) + g2d.drawRect(map("xmin"), map("ymin"), map("xmax") - map("xmin"), map("ymax") - map("ymin")) + // Write the name of the bounding box + if (names.isDefined) { + val x = map("xmin") - stroke.getOrElse(3) + val y = map("ymin") + val h = fm.getHeight + val w = fm.charsWidth(names.get(idx).toCharArray, 0, names.get(idx).length()) + g2d.fillRect(x, y - h, w, h) + g2d.setColor(Color.WHITE) + g2d.drawString(names.get(idx), x, y) + } + } + g2d.dispose() + } + } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala index 7d6f31e930ad..f72223d1e4da 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Image.scala @@ -20,15 +20,16 @@ package org.apache.mxnet.javaapi import java.awt.image.BufferedImage // scalastyle:on import java.io.InputStream +import scala.collection.JavaConverters._ object Image { /** * Decode image with OpenCV. * Note: return image in RGB by default, instead of OpenCV's default BGR. - * @param buf Buffer containing binary encoded image - * @param flag Convert decoded image to grayscale (0) or color (1). + * @param buf Buffer containing binary encoded image + * @param flag Convert decoded image to grayscale (0) or color (1). * @param toRGB Whether to convert decoded image - * to mxnet's default RGB format (instead of opencv's default BGR). + * to mxnet's default RGB format (instead of opencv's default BGR). 
* @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = { @@ -43,8 +44,8 @@ object Image { * Same imageDecode with InputStream * * @param inputStream the inputStream of the image - * @param flag Convert decoded image to grayscale (0) or color (1). - * @param toRGB Whether to convert decoded image + * @param flag Convert decoded image to grayscale (0) or color (1). + * @param toRGB Whether to convert decoded image * @return NDArray in HWC format with DType [[DType.UInt8]] */ def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = { @@ -60,7 +61,7 @@ object Image { * Note: return image in RGB by default, instead of OpenCV's default BGR. * @param filename Name of the image file to be loaded. * @param flag Convert decoded image to grayscale (0) or color (1). - * @param toRGB Whether to convert decoded image to mxnet's default RGB format + * @param toRGB Whether to convert decoded image to mxnet's default RGB format * (instead of opencv's default BGR). * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]] */ @@ -74,10 +75,10 @@ object Image { /** * Resize image with OpenCV. - * @param src source image in NDArray - * @param w Width of resized image. - * @param h Height of resized image. - * @param interp Interpolation method (default=cv2.INTER_LINEAR). + * @param src source image in NDArray + * @param w Width of resized image. + * @param h Height of resized image. + * @param interp Interpolation method (default=cv2.INTER_LINEAR). 
* @return org.apache.mxnet.NDArray */ def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = { @@ -92,10 +93,10 @@ object Image { /** * Do a fixed crop on the image * @param src Src image in NDArray - * @param x0 starting x point - * @param y0 starting y point - * @param w width of the image - * @param h height of the image + * @param x0 starting x point + * @param y0 starting y point + * @param w width of the image + * @param h height of the image * @return cropped NDArray */ def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = { @@ -111,4 +112,21 @@ object Image { def toImage(src: NDArray): BufferedImage = { org.apache.mxnet.Image.toImage(src) } + + /** + * Draw bounding boxes on the image + * @param src buffered image to draw on + * @param coordinate Contains Map of xmin, xmax, ymin, ymax + * corresponding to top-left and down-right points + * @param names The name set of the bounding box; null draws the boxes without labels + */ + def drawBoundingBox(src: BufferedImage, + coordinate: java.util.List[ + java.util.Map[java.lang.String, java.lang.Integer]], + names: java.util.List[java.lang.String]): Unit = { + val coord = coordinate.asScala.map( + _.asScala.map{case (name, value) => (name, Integer2int(value))}.toMap).toArray + org.apache.mxnet.Image.drawBoundingBox(src, coord, Option(names).map(_.asScala.toArray)) + } + } diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java index 0092744a21a8..f5515dc053a8 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/ImageTest.java @@ -20,8 +20,15 @@ import org.apache.commons.io.FileUtils; import org.junit.BeforeClass; import org.junit.Test; + +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; import java.io.File; import java.net.URL; +import java.util.ArrayList; +import java.util.HashMap; +import 
java.util.List; +import java.util.Map; import static org.junit.Assert.assertArrayEquals; @@ -56,12 +63,23 @@ public static void downloadFile() throws Exception { } @Test - public void testImageProcess() { + public void testImageProcess() throws Exception { NDArray nd = Image.imRead(imLocation, 1, true); assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3}); NDArray nd2 = Image.imResize(nd, 224, 224, null); assertArrayEquals(nd2.shape().toArray(), new int[]{224, 224, 3}); NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224); Image.toImage(cropped); + BufferedImage buf = ImageIO.read(new File(imLocation)); + Map<String, Integer> map = new HashMap<>(); + map.put("xmin", 190); + map.put("xmax", 850); + map.put("ymin", 50); + map.put("ymax", 450); + List<Map<String, Integer>> box = new ArrayList<>(); + box.add(map); + List<String> names = new ArrayList<>(); + names.add("pug"); + Image.drawBoundingBox(buf, box, names); } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala index 67815ad6c108..d4cf35af186f 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala @@ -97,4 +97,25 @@ class ImageSuite extends FunSuite with BeforeAndAfterAll { logger.info(s"converted image stored in ${tempDirPath + "/inputImages/out.png"}") } + test("Test draw Bounding box") { + val buf = ImageIO.read(new File(imLocation)) + val box = Array( + Map("xmin" -> 190, "xmax" -> 850, "ymin" -> 50, "ymax" -> 450), + Map("xmin" -> 200, "xmax" -> 350, "ymin" -> 440, "ymax" -> 530) + ) + val names = Array("pug", "cookie") + Image.drawBoundingBox(buf, box, Some(names), fontSizeMult = Some(1.4f)) + val tempDirPath = System.getProperty("java.io.tmpdir") + ImageIO.write(buf, "png", new File(tempDirPath + "/inputImages/out2.png")) + logger.info(s"converted image stored in ${tempDirPath + "/inputImages/out2.png"}") + for (coord <- box) { + val 
topLeft = buf.getRGB(coord("xmin"), coord("ymin")) + val downLeft = buf.getRGB(coord("xmin"), coord("ymax")) + val topRight = buf.getRGB(coord("xmax"), coord("ymin")) + val downRight = buf.getRGB(coord("xmax"), coord("ymax")) + assert(downLeft == downRight) + assert(topRight == downRight) + } + } + } diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md index 8a9ed3e1736b..4c4512f152c8 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/README.md @@ -84,7 +84,7 @@ After the previous steps, you should be able to run the code using the following From the `scala-package/examples/scripts/infer/objectdetector/` folder run: ```bash -./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images +./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images ``` **Notes**: diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java index a9c00f7f1d81..31b8514de345 100644 --- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java +++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer/objectdetector/SSDClassifierExample.java @@ -28,12 +28,11 @@ import org.apache.mxnet.infer.javaapi.ObjectDetector; // scalastyle:off +import javax.imageio.ImageIO; import java.awt.image.BufferedImage; // scalastyle:on -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.*; import java.io.File; @@ 
-128,22 +127,34 @@ public static void main(String[] args) { try { Shape inputShape = new Shape(new int[]{1, 3, 512, 512}); Shape outputShape = new Shape(new int[]{1, 6132, 6}); - - - int width = inputShape.get(2); - int height = inputShape.get(3); + StringBuilder outputStr = new StringBuilder().append("\n"); List<List<ObjectDetectorOutput>> output = runObjectDetectionSingle(mdprefixDir, imgPath, context); - + + // Creating Bounding box material + BufferedImage buf = ImageIO.read(new File(imgPath)); + int width = buf.getWidth(); + int height = buf.getHeight(); + List<Map<String, Integer>> boxes = new ArrayList<>(); + List<String> names = new ArrayList<>(); for (List<ObjectDetectorOutput> ele : output) { for (ObjectDetectorOutput i : ele) { outputStr.append("Class: " + i.getClassName() + "\n"); outputStr.append("Probabilties: " + i.getProbability() + "\n"); - - List<Float> coord = Arrays.asList(i.getXMin() * width, - i.getXMax() * height, i.getYMin() * width, i.getYMax() * height); + names.add(i.getClassName()); + Map<String, Integer> map = new HashMap<>(); + float xmin = i.getXMin() * width; + float xmax = i.getXMax() * width; + float ymin = i.getYMin() * height; + float ymax = i.getYMax() * height; + List<Float> coord = Arrays.asList(xmin, xmax, ymin, ymax); + map.put("xmin", (int) xmin); + map.put("xmax", (int) xmax); + map.put("ymin", (int) ymin); + map.put("ymax", (int) ymax); + boxes.add(map); StringBuilder sb = new StringBuilder(); for (float c : coord) { sb.append(", ").append(c); @@ -152,7 +163,12 @@ public static void main(String[] args) { } } logger.info(outputStr.toString()); - + + // Convert to image + Image.drawBoundingBox(buf, boxes, names); + File outputFile = new File("boundingImage.png"); + ImageIO.write(buf, "png", outputFile); + List<List<List<ObjectDetectorOutput>>> outputList = runObjectDetectionBatch(mdprefixDir, imgDir, context); @@ -177,7 +193,6 @@ public static void main(String[] args) { } } logger.info(outputStr.toString()); - } catch (Exception e) { logger.error(e.getMessage(), e); parser.printUsage(System.err); diff --git 
a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala index e29f068d5558..28a578cae79f 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala @@ -132,7 +132,7 @@ class ObjectDetector(modelPathPrefix: String, if (topK.isDefined) { var sortedIndices = predictResult.zipWithIndex.sortBy(-_._1(1)).map(_._2) sortedIndices = sortedIndices.take(topK.get) - // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax + // takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax] result = sortedIndices.map(idx => (synset(predictResult(idx)(0).toInt), predictResult(idx).takeRight(5))).toIndexedSeq diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala index 5a6ac7599fa9..32fd87e05f69 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/ObjectDetectorOutput.scala @@ -52,14 +52,14 @@ class ObjectDetectorOutput (className: String, args: Array[Float]){ * * @return Float of the max X coordinate for the object bounding box */ - def getXMax: Float = args(2) + def getXMax: Float = args(3) /** * Gets the minimum Y coordinate for the bounding box containing the predicted object. * * @return Float of the min Y coordinate for the object bounding box */ - def getYMin: Float = args(3) + def getYMin: Float = args(2) /** * Gets the maximum Y coordinate for the bounding box containing the predicted object. 
diff --git a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java index 04041fcda9bf..6f3df86b8e74 100644 --- a/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java +++ b/scala-package/infer/src/test/java/org/apache/mxnet/infer/javaapi/ObjectDetectorOutputTest.java @@ -36,8 +36,8 @@ public void testConstructor() { Assert.assertEquals(odOutput.getClassName(), predictedClassName); Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta); Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta); - Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta); - Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 3f, delta); + Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 2f, delta); Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta); }