Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1285] Draw bounding box with Scala/Java Image API #14474

Merged
merged 5 commits into from
Mar 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scala-package/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ core/src/main/scala/org/apache/mxnet/SymbolBase.scala
core/src/main/scala/org/apache/mxnet/SymbolRandomAPIBase.scala
examples/scripts/infer/images/
examples/scripts/infer/models/
examples/scripts/infer/objectdetector/boundingImage.png
54 changes: 54 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.mxnet
// scalastyle:off
import java.awt.{BasicStroke, Color, Graphics2D}
import java.awt.image.BufferedImage
// scalastyle:on
import java.io.InputStream
Expand Down Expand Up @@ -182,4 +183,57 @@ object Image {
img
}

/**
* Helper function to generate ramdom colors
* @param transparency The transparency level
* @return Color
*/
private def randomColor(transparency: Option[Float] = Some(1.0f)) : Color = {
new Color(
Math.random().toFloat, Math.random().toFloat, Math.random().toFloat,
transparency.get
)
}

/**
* Method to draw bounding boxes for an image
* @param src Source of the buffered image
* @param coordinate Contains Map of xmin, xmax, ymin, ymax
* corresponding to top-left and down-right points
* @param names The name set of the bounding box
* @param stroke Thickness of the bounding box
* @param fontSizeMult Font size multiplier
* @param transparency Transparency of the bounding box
*/
def drawBoundingBox(src: BufferedImage, coordinate: Array[Map[String, Int]],
names: Option[Array[String]] = None,
stroke : Option[Int] = Some(3),
fontSizeMult : Option[Float] = Some(1.0f),
transparency: Option[Float] = Some(1.0f)): Unit = {
val g2d : Graphics2D = src.createGraphics()
g2d.setStroke(new BasicStroke(stroke.get))
// Increase the size of font
val currentFont = g2d.getFont
val newFont = currentFont.deriveFont(currentFont.getSize * fontSizeMult.get)
g2d.setFont(newFont)
// Get font metrics to draw the font box
val fm = g2d.getFontMetrics(newFont)
for (idx <- coordinate.indices) {
val map = coordinate(idx)
g2d.setColor(randomColor(transparency).darker())
g2d.drawRect(map("xmin"), map("ymin"), map("xmax") - map("xmin"), map("ymax") - map("ymin"))
// Write the name of the bounding box
if (names.isDefined) {
val x = map("xmin") - stroke.get
val y = map("ymin")
val h = fm.getHeight
val w = fm.charsWidth(names.get(idx).toCharArray, 0, names.get(idx).length())
g2d.fillRect(x, y - h, w, h)
g2d.setColor(Color.WHITE)
g2d.drawString(names.get(idx), x, y)
lanking520 marked this conversation as resolved.
Show resolved Hide resolved
}
}
g2d.dispose()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ package org.apache.mxnet.javaapi
import java.awt.image.BufferedImage
// scalastyle:on
import java.io.InputStream
import scala.collection.JavaConverters._

object Image {
/**
* Decode image with OpenCV.
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param buf Buffer containing binary encoded image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param buf Buffer containing binary encoded image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* to mxnet's default RGB format (instead of opencv's default BGR).
* to mxnet's default RGB format (instead of opencv's default BGR).
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
Expand All @@ -43,8 +44,8 @@ object Image {
* Same imageDecode with InputStream
*
* @param inputStream the inputStream of the image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image
* @return NDArray in HWC format with DType [[DType.UInt8]]
*/
def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = {
Expand All @@ -60,7 +61,7 @@ object Image {
* Note: return image in RGB by default, instead of OpenCV's default BGR.
* @param filename Name of the image file to be loaded.
* @param flag Convert decoded image to grayscale (0) or color (1).
* @param toRGB Whether to convert decoded image to mxnet's default RGB format
* @param toRGB Whether to convert decoded image to mxnet's default RGB format
* (instead of opencv's default BGR).
* @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
*/
Expand All @@ -74,10 +75,10 @@ object Image {

/**
* Resize image with OpenCV.
* @param src source image in NDArray
* @param w Width of resized image.
* @param h Height of resized image.
* @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @param src source image in NDArray
* @param w Width of resized image.
* @param h Height of resized image.
* @param interp Interpolation method (default=cv2.INTER_LINEAR).
* @return org.apache.mxnet.NDArray
*/
def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
Expand All @@ -92,10 +93,10 @@ object Image {
/**
* Do a fixed crop on the image
* @param src Src image in NDArray
* @param x0 starting x point
* @param y0 starting y point
* @param w width of the image
* @param h height of the image
* @param x0 starting x point
* @param y0 starting y point
* @param w width of the image
* @param h height of the image
* @return cropped NDArray
*/
def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
Expand All @@ -111,4 +112,21 @@ object Image {
def toImage(src: NDArray): BufferedImage = {
org.apache.mxnet.Image.toImage(src)
}

/**
* Draw bounding boxes on the image
* @param src buffered image to draw on
* @param coordinate Contains Map of xmin, xmax, ymin, ymax
* corresponding to top-left and down-right points
* @param names The name set of the bounding box
*/
def drawBoundingBox(src: BufferedImage,
coordinate: java.util.List[
java.util.Map[java.lang.String, java.lang.Integer]],
names: java.util.List[java.lang.String]): Unit = {
val coord = coordinate.asScala.map(
_.asScala.map{case (name, value) => (name, Integer2int(value))}.toMap).toArray
org.apache.mxnet.Image.drawBoundingBox(src, coord, Option(names.asScala.toArray))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@
import org.apache.commons.io.FileUtils;
import org.junit.BeforeClass;
import org.junit.Test;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertArrayEquals;

Expand Down Expand Up @@ -56,12 +63,23 @@ public static void downloadFile() throws Exception {
}

@Test
public void testImageProcess() {
public void testImageProcess() throws Exception {
NDArray nd = Image.imRead(imLocation, 1, true);
assertArrayEquals(nd.shape().toArray(), new int[]{576, 1024, 3});
NDArray nd2 = Image.imResize(nd, 224, 224, null);
assertArrayEquals(nd2.shape().toArray(), new int[]{224, 224, 3});
NDArray cropped = Image.fixedCrop(nd, 0, 0, 224, 224);
Image.toImage(cropped);
BufferedImage buf = ImageIO.read(new File(imLocation));
Map<String, Integer> map = new HashMap<>();
map.put("xmin", 190);
map.put("xmax", 850);
map.put("ymin", 50);
map.put("ymax", 450);
List<Map<String, Integer>> box = new ArrayList<>();
box.add(map);
List<String> names = new ArrayList<>();
names.add("pug");
Image.drawBoundingBox(buf, box, names);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,25 @@ class ImageSuite extends FunSuite with BeforeAndAfterAll {
logger.info(s"converted image stored in ${tempDirPath + "/inputImages/out.png"}")
}

test("Test draw Bounding box") {
val buf = ImageIO.read(new File(imLocation))
val box = Array(
Map("xmin" -> 190, "xmax" -> 850, "ymin" -> 50, "ymax" -> 450),
Map("xmin" -> 200, "xmax" -> 350, "ymin" -> 440, "ymax" -> 530)
)
val names = Array("pug", "cookie")
Image.drawBoundingBox(buf, box, Some(names), fontSizeMult = Some(1.4f))
val tempDirPath = System.getProperty("java.io.tmpdir")
ImageIO.write(buf, "png", new File(tempDirPath + "/inputImages/out2.png"))
logger.info(s"converted image stored in ${tempDirPath + "/inputImages/out2.png"}")
for (coord <- box) {
val topLeft = buf.getRGB(coord("xmin"), coord("ymin"))
val downLeft = buf.getRGB(coord("xmin"), coord("ymax"))
val topRight = buf.getRGB(coord("xmax"), coord("ymin"))
val downRight = buf.getRGB(coord("xmax"), coord("ymax"))
require(downLeft == downRight)
require(topRight == downRight)
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also clean up the file we created in the unit test here ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it is created under the temp folder, system will clean it up in a certain period

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ After the previous steps, you should be able to run the code using the following
From the `scala-package/examples/scripts/infer/objectdetector/` folder run:

```bash
./run_ssd_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
./run_ssd_java_example.sh ../models/resnet50_ssd/resnet50_ssd_model ../images/dog.jpg ../images
```

**Notes**:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
import org.apache.mxnet.infer.javaapi.ObjectDetector;

// scalastyle:off
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
// scalastyle:on

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;

import java.io.File;

Expand Down Expand Up @@ -128,22 +127,34 @@ public static void main(String[] args) {
try {
Shape inputShape = new Shape(new int[]{1, 3, 512, 512});
Shape outputShape = new Shape(new int[]{1, 6132, 6});


int width = inputShape.get(2);
int height = inputShape.get(3);

StringBuilder outputStr = new StringBuilder().append("\n");

List<List<ObjectDetectorOutput>> output
= runObjectDetectionSingle(mdprefixDir, imgPath, context);


// Creating Bounding box material
BufferedImage buf = ImageIO.read(new File(imgPath));
int width = buf.getWidth();
int height = buf.getHeight();
List<Map<String, Integer>> boxes = new ArrayList<>();
List<String> names = new ArrayList<>();
for (List<ObjectDetectorOutput> ele : output) {
for (ObjectDetectorOutput i : ele) {
outputStr.append("Class: " + i.getClassName() + "\n");
outputStr.append("Probabilties: " + i.getProbability() + "\n");

List<Float> coord = Arrays.asList(i.getXMin() * width,
i.getXMax() * height, i.getYMin() * width, i.getYMax() * height);
names.add(i.getClassName());
Map<String, Integer> map = new HashMap<>();
float xmin = i.getXMin() * width;
float xmax = i.getXMax() * width;
float ymin = i.getYMin() * height;
float ymax = i.getYMax() * height;
List<Float> coord = Arrays.asList(xmin, xmax, ymin, ymax);
map.put("xmin", (int) xmin);
map.put("xmax", (int) xmax);
map.put("ymin", (int) ymin);
map.put("ymax", (int) ymax);
boxes.add(map);
StringBuilder sb = new StringBuilder();
for (float c : coord) {
sb.append(", ").append(c);
Expand All @@ -152,7 +163,12 @@ public static void main(String[] args) {
}
}
logger.info(outputStr.toString());


// Covert to image
Image.drawBoundingBox(buf, boxes, names);
File outputFile = new File("boundingImage.png");
ImageIO.write(buf, "png", outputFile);

List<List<List<ObjectDetectorOutput>>> outputList =
runObjectDetectionBatch(mdprefixDir, imgDir, context);

Expand All @@ -177,7 +193,6 @@ public static void main(String[] args) {
}
}
logger.info(outputStr.toString());

} catch (Exception e) {
logger.error(e.getMessage(), e);
parser.printUsage(System.err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class ObjectDetector(modelPathPrefix: String,
if (topK.isDefined) {
var sortedIndices = predictResult.zipWithIndex.sortBy(-_._1(1)).map(_._2)
sortedIndices = sortedIndices.take(topK.get)
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax
// takeRight(5) would provide the output as Array[Accuracy, Xmin, Ymin, Xmax, Ymax]
result = sortedIndices.map(idx
=> (synset(predictResult(idx)(0).toInt),
predictResult(idx).takeRight(5))).toIndexedSeq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ class ObjectDetectorOutput (className: String, args: Array[Float]){
*
* @return Float of the max X coordinate for the object bounding box
*/
def getXMax: Float = args(2)
def getXMax: Float = args(3)

/**
* Gets the minimum Y coordinate for the bounding box containing the predicted object.
*
* @return Float of the min Y coordinate for the object bounding box
*/
def getYMin: Float = args(3)
def getYMin: Float = args(2)

/**
* Gets the maximum Y coordinate for the bounding box containing the predicted object.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public void testConstructor() {
Assert.assertEquals(odOutput.getClassName(), predictedClassName);
Assert.assertEquals("Threshold not matching", odOutput.getProbability(), 0f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMin(), 1f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 2f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 3f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getXMax(), 3f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMin(), 2f, delta);
Assert.assertEquals("Threshold not matching", odOutput.getYMax(), 4f, delta);

}
Expand Down