Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudh2290 committed Apr 13, 2019
2 parents 6574e91 + 5fc5c27 commit 716715e
Show file tree
Hide file tree
Showing 73 changed files with 1,365 additions and 499 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
:aliases {"run-detector" ["run" "--" "-m" "models/resnet50_ssd/resnet50_ssd_model" "-i" "images/dog.jpg" "-d" "images/"]}
:dependencies [[org.clojure/clojure "1.9.0"]
[org.clojure/tools.cli "0.4.1"]
[origami "4.0.0-3"]
[org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]]
:main ^:skip-aot infer.objectdetector-example
:profiles {:uberjar {:aot :all}})

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
(ns infer.objectdetector-example
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.image :as image]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
[infer.draw :as draw]
[clojure.string :refer [join]]
[clojure.string :as string]
[clojure.tools.cli :refer [parse-opts]])
(:gen-class))
(:gen-class)
(:import (javax.imageio ImageIO)
(java.io File)))

(defn check-valid-dir
"Check that the input directory exists"
Expand Down Expand Up @@ -54,35 +56,44 @@
:validate [check-valid-dir "Input directory not found"]]
["-h" "--help"]])

(defn result->map [{:keys [class prob x-min y-min x-max y-max]}]
(hash-map
:label class
:confidence (int (* 100 prob))
:top-left [x-min y-min]
:bottom-right [x-max y-max]))

(defn print-results [results]
(doseq [_r results]
(println (format "Class: %s Confidence=%s Coords=(%s, %s)"
(_r :label)
(_r :confidence)
(_r :top-left)
(_r :bottom-right)))))
(defn process-result! [output-dir image-path predictions]
(println "looking at image" image-path)
(println "predictions: " predictions)
(let [buf (ImageIO/read (new File image-path))
width (.getWidth buf)
height (.getHeight buf)
names (mapv :class predictions)
coords (mapv (fn [prediction]
(-> prediction
(update :x-min #(* width %))
(update :x-max #(* width %))
(update :y-min #(* height %))
(update :y-max #(* height %))))
predictions)
new-img (-> (ImageIO/read (new File image-path))
(image/draw-bounding-box! coords
{:stroke 2
:names (mapv #(str (:class %) "-" (:prob %))
predictions)
:transparency 0.5

:font-size-mult 1.0}))]
(->> (string/split image-path #"\/")
last
(io/file output-dir)
(ImageIO/write new-img "jpg"))))

(defn process-results [images results output-dir]
(dotimes [i (count images)]
(let [image (nth images i) _results (map result->map (nth results i))]
(println "processing: " image)
(print-results _results)
(draw/draw-bounds image _results output-dir))))
(doall (map (partial process-result! output-dir) images results)))

(defn detect-single-image
"Detect objects in a single image and print top-5 predictions"
([detector input-dir] (detect-single-image detector input-dir "results"))
([detector input-image output-dir]
(.mkdir (io/file output-dir))
(let [image (infer/load-image-from-file input-image)
topk 5
topk 3
res (infer/detect-objects detector image topk)
]
(process-results
Expand All @@ -109,7 +120,7 @@
(apply concat
(for [image-files image-file-batches]
(let [image-batch (infer/load-image-paths image-files)
topk 5
topk 3
res (infer/detect-objects-batch detector image-batch topk) ]
(process-results
image-files
Expand Down Expand Up @@ -143,5 +154,5 @@
(parse-opts args cli-options)]
(cond
(:help options) (println summary)
(some? errors) (println (join "\n" errors))
(some? errors) (println (string/join "\n" errors))
:else (run-detector options))))
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@
{:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
(clojure.pprint/pprint predictions)
(is (some? predictions))
(is (= 5 (count predictions)))
(is (= 3 (count predictions)))
(is (string? class))
(is (< 0.8 prob))
(is (every? #(< 0 % 1) [x-min x-max y-min y-max]))
(is (= #{"dog" "person" "bicycle" "car"} (set (mapv :class predictions))))))
(is (= #{"dog" "bicycle" "car"} (set (mapv :class predictions))))))

(deftest test-batch-detection
(let [detector (create-detector)
Expand All @@ -60,7 +60,7 @@
predictions (first batch-predictions)
{:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)]
(is (some? batch-predictions))
(is (= 5 (count predictions)))
(is (= 3 (count predictions)))
(is (string? class))
(is (< 0.8 prob))
(println [x-min x-max y-min y-max])
Expand Down
2 changes: 1 addition & 1 deletion contrib/clojure-package/integration-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ lein install
# then run through the examples
EXAMPLES_HOME=${MXNET_HOME}/contrib/clojure-package/examples
# use AWK pattern for blacklisting
TEST_CASES=`find ${EXAMPLES_HOME} -name test | awk '!/dontselect1|cnn-text-classification|gan|neural-style|infer|pre-trained-models/'`
TEST_CASES=`find ${EXAMPLES_HOME} -name test | awk '!/dontselect1|cnn-text-classification|gan|neural-style|pre-trained-models/'`
for i in $TEST_CASES ; do
cd ${i} && lein test
done
26 changes: 13 additions & 13 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/image.clj
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,11 @@
(Image/toImage input))

(s/def ::buffered-image #(instance? BufferedImage %))
(s/def ::xmin integer?)
(s/def ::xmax integer?)
(s/def ::ymin integer?)
(s/def ::ymax integer?)
(s/def ::coordinate (s/keys :req-un [::xmin ::xmax ::ymin ::ymax]))
(s/def ::x-min number?)
(s/def ::x-max number?)
(s/def ::y-min number?)
(s/def ::y-max number?)
(s/def ::coordinate (s/keys :req-un [::x-min ::x-max ::y-min ::y-max]))
(s/def ::coordinates (s/coll-of ::coordinate))
(s/def ::names (s/nilable (s/coll-of string?)))
(s/def ::stroke (s/and integer? pos?))
Expand All @@ -217,11 +217,11 @@

(defn- convert-coordinate
"Convert bounding box coordinate to Scala correct types."
[{:keys [xmin xmax ymin ymax]}]
{:xmin (int xmin)
:xmax (int xmax)
:ymin (int ymin)
:ymax (int ymax)})
[{:keys [x-min x-max y-min y-max]}]
{:xmin (int x-min)
:xmax (int x-max)
:ymin (int y-min)
:ymax (int y-max)})

(defn draw-bounding-box!
"Draw bounding boxes on `buffered-image` and Mutate the input image.
Expand All @@ -233,9 +233,9 @@
`transparency`: float in (0.0, 1.0) - Transparency of the bounding box
returns: Modified `buffered-image`
Ex:
(draw-bounding-box! img [{:xmin 0 :xmax 100 :ymin 0 :ymax 100}])
(draw-bounding-box! [{:xmin 190 :xmax 850 :ymin 50 :ymax 450}
{:xmin 200 :xmax 350 :ymin 440 :ymax 530}]
(draw-bounding-box! img [{:x-min 0 :x-max 100 :y-min 0 :y-max 100}])
(draw-bounding-box! [{:x-min 190 :x-max 850 :y-min 50 :y-max 450}
{:x-min 200 :x-max 350 :y-min 440 :y-max 530}]
{:stroke 2
:names [\"pug\" \"cookie\"]
:transparency 0.8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
[org.apache.clojure-mxnet.ndarray :as ndarray]
[clojure.java.io :as io]
[clojure.test :refer :all])
(:import (javax.imageio ImageIO)))
(:import (javax.imageio ImageIO)
(java.io File)))

(def tmp-dir (System/getProperty "java.io.tmpdir"))
(def image-path (.getAbsolutePath (io/file tmp-dir "Pug-Cookie.jpg")))
Expand Down Expand Up @@ -76,4 +77,15 @@
(let [img-arr (image/read-image image-path)
resized-arr (image/resize-image img-arr 224 224)
new-img (image/to-image resized-arr)]
(is (= true (ImageIO/write new-img "png" (io/file tmp-dir "out.png"))))))
(is (ImageIO/write new-img "png" (io/file tmp-dir "out.png")))))

(deftest test-draw-bounding-box!
(let [orig-img (ImageIO/read (new File image-path))
new-img (-> orig-img
(image/draw-bounding-box! [{:x-min 190 :x-max 850 :y-min 50 :y-max 450}
{:x-min 200 :x-max 350 :y-min 440 :y-max 530}]
{:stroke 2
:names ["pug" "cookie"]
:transparency 0.8
:font-size-mult 2.0}))]
(is (ImageIO/write new-img "png" (io/file tmp-dir "out.png")))))
1 change: 1 addition & 0 deletions example/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ The following models have been tested on Linux systems.
|[ResNet152-V2](#8)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/resnet/152-layers/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|76.76%/93.03%|76.48%/92.96%|
|[Inception-BN](#9)|[MXNet ModelZoo](http://data.mxnet.io/models/imagenet/inception-bn/)|[Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec)|72.09%/90.60%|72.00%/90.53%|
| [SSD-VGG16](#10) | [example/ssd](https://github.com/apache/incubator-mxnet/tree/master/example/ssd) | VOC2007/2012 | 0.8366 mAP | 0.8364 mAP |
| [SSD-VGG16](#10) | [example/ssd](https://github.com/apache/incubator-mxnet/tree/master/example/ssd) | COCO2014 | 0.2552 mAP | 0.253 mAP |

<h3 id='3'>ResNet50-V1</h3>

Expand Down
56 changes: 51 additions & 5 deletions example/ssd/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ remarkable traits of MXNet.
Due to the permission issue, this example is maintained in this [repository](https://github.com/zhreshold/mxnet-ssd) separately. You can use the link regarding specific per example [issues](https://github.com/zhreshold/mxnet-ssd/issues).

### What's new
* Support training and inference on COCO dataset. Int8 inference achieves 0.253 mAP on CPU with MKL-DNN backend, which is a comparable accuracy to FP32 (0.2552 mAP).
* Support uint8 inference on CPU with MKL-DNN backend. Uint8 inference achieves 0.8364 mAP, which is a comparable accuracy to FP32 (0.8366 mAP).
* Added live camera capture and detection display (run with --camera flag). Example:
`./demo.py --camera --cpu --frame-resize 0.5`
Expand Down Expand Up @@ -119,9 +120,9 @@ You can use `./demo.py --camera` to use a video capture device with opencv such
will open a window that will display the camera output together with the detections. You can play
with the detection threshold to get more or less detections.

### Train the model
### Train the model on VOC
* Note that we recommend to use gluon-cv to train the model, please refer to [gluon-cv ssd](https://gluon-cv.mxnet.io/build/examples_detection/train_ssd_voc.html).
This example only covers training on Pascal VOC dataset. Other datasets should
This example only covers training on Pascal VOC or MS COCO dataset. Other datasets should
be easily supported by adding subclass derived from class `Imdb` in `dataset/imdb.py`.
See example of `dataset/pascal_voc.py` for details.
* Download the converted pretrained `vgg16_reduced` model [here](https://github.com/zhreshold/mxnet-ssd/releases/download/v0.2-beta/vgg16_reduced.zip), unzip `.param` and `.json` files
Expand Down Expand Up @@ -166,16 +167,53 @@ Check `python train.py --help` for more training options. For example, if you ha
python train.py --gpus 0,1,2,3 --batch-size 32
```

### Train the model on COCO
* Download the COCO2014 dataset, skip this step if you already have one.
```
cd /path/to/where_you_store_datasets/
wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
# Extract the data.
unzip train2014.zip
unzip val2014.zip
unzip annotations_trainval2014.zip
```
* We are going to use `train2014,valminusminival2014` set in COCO2014 for training and `minival2014` for evaluation as a common strategy.
* Then link `COCO2014` folder to `data/coco` by default:
```
ln -s /path/to/COCO2014 /path/to/incubator-mxnet/example/ssd/data/coco
```
Use hard link instead of copy could save us a bit disk space.
* Create packed binary file for faster training:
```
# cd /path/to/incubator-mxnet/example/ssd
bash tools/prepare_coco.sh
# or if you are using windows
python tools/prepare_dataset.py --dataset coco --set train2014,valminusminival2014 --target ./data/train.lst --root ./data/coco
python tools/prepare_dataset.py --dataset coco --set minival2014 --target ./data/val.lst --root ./data/coco --no-shuffle
```
* Start training:
```
# cd /path/to/incubator-mxnet/example/ssd
python train.py --label-width=560 --num-class=80 --class-names=./dataset/names/coco_label --pretrained="" --num-example=117265 --batch-size=64
```

### Evalute trained model
Make sure you have val.rec as validation dataset. It's the same one as used in training. Use:
```
# cd /path/to/incubator-mxnet/example/ssd
python evaluate.py --gpus 0,1 --batch-size 128 --epoch 0
# Evaluate on COCO dataset
python evaluate.py --gpus 0,1 --batch-size 128 --epoch 0 --num-class=80 --class-names=./dataset/names/mscoco.names
```

### Quantize model

Follow the [Train instructions](https://github.com/apache/incubator-mxnet/tree/master/example/ssd#train-the-model) to train a FP32 `SSD-VGG16_reduced_300x300` model based on Pascal VOC dataset. You can also download our [SSD-VGG16 pre-trained model](http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/ssd_vgg16_reduced_300-dd479559.zip) and [packed binary data](http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/ssd-val-fc19a535.zip). Create `model` and `data` directories if they're not exist, extract the zip files, then rename the uncompressed files as follows (eg, rename `ssd-val-fc19a535.idx` to `val.idx`, `ssd-val-fc19a535.lst` to `val.lst`, `ssd-val-fc19a535.rec` to `val.rec`, `ssd_vgg16_reduced_300-dd479559.params` to `ssd_vgg16_reduced_300-0000.params`, `ssd_vgg16_reduced_300-symbol-dd479559.json` to `ssd_vgg16_reduced_300-symbol.json`.)
To quantize a model on VOC dataset, follow the [Train instructions](https://github.com/apache/incubator-mxnet/tree/master/example/ssd#train-the-model-on-VOC) to train a FP32 `SSD-VGG16_reduced_300x300` model based on Pascal VOC dataset. You can also download our [SSD-VGG16 pre-trained model](http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/ssd_vgg16_reduced_300-dd479559.zip) and [packed binary data](http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/ssd-val-fc19a535.zip). Create `model` and `data` directories if they're not exist, extract the zip files, then rename the uncompressed files as follows (eg, rename `ssd-val-fc19a535.idx` to `val.idx`, `ssd-val-fc19a535.lst` to `val.lst`, `ssd-val-fc19a535.rec` to `val.rec`, `ssd_vgg16_reduced_300-dd479559.params` to `ssd_vgg16_reduced_300-0000.params`, `ssd_vgg16_reduced_300-symbol-dd479559.json` to `ssd_vgg16_reduced_300-symbol.json`.)

To quantize a model on COCO dataset, follow the [Train instructions](https://github.com/apache/incubator-mxnet/tree/master/example/ssd#train-the-model-on-COCO) to train a FP32 `SSD-VGG16_reduced_300x300` model based on COCO dataset. You can also download our [SSD-VGG16 pre-trained model](http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/ssd_vgg16_reduced_300-7fedd4ad.zip) and [packed binary data](http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/ssd_coco-val-e91096e8.zip). Create `model` and `data` directories if they're not exist, extract the zip files, then rename the uncompressed files as follows (eg, rename `ssd_coco-val-e91096e8.idx` to `val.idx`, `ssd_coco-val-e91096e8.lst` to `val.lst`, `ssd_coco-val-e91096e8.rec` to `val.rec`, `ssd_vgg16_reduced_300-7fedd4ad.params` to `ssd_vgg16_reduced_300-0000.params`, `ssd_vgg16_reduced_300-symbol-7fedd4ad.json` to `ssd_vgg16_reduced_300-symbol.json`.)

```
data/
Expand All @@ -199,12 +237,20 @@ After quantization, INT8 models will be saved in `model/` dictionary. Use the f
# USE MKLDNN AS SUBGRAPH BACKEND
export MXNET_SUBGRAPH_BACKEND=MKLDNN
# Launch FP32 Inference
# Launch FP32 Inference on VOC dataset
python evaluate.py --cpu --num-batch 10 --batch-size 224 --deploy --prefix=./model/ssd_
# Launch INT8 Inference
# Launch INT8 Inference on VOC dataset
python evaluate.py --cpu --num-batch 10 --batch-size 224 --deploy --prefix=./model/cqssd_
# Launch FP32 Inference on COCO dataset
python evaluate.py --cpu --num-batch 10 --batch-size 224 --deploy --prefix=./model/ssd_ --num-class=80 --class-names=./dataset/names/mscoco.names
# Launch INT8 Inference on COCO dataset
python evaluate.py --cpu --num-batch 10 --batch-size 224 --deploy --prefix=./model/cqssd_ --num-class=80 --class-names=./dataset/names/mscoco.names
# Launch dummy data Inference
python benchmark_score.py --deploy --prefix=./model/ssd_
python benchmark_score.py --deploy --prefix=./model/cqssd_
Expand Down
Loading

0 comments on commit 716715e

Please sign in to comment.