Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Address comments and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Kedar Bellare committed Dec 22, 2018
1 parent fb7ee7f commit a15a098
Show file tree
Hide file tree
Showing 17 changed files with 306 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ Run image classification using clojure infer package.
```
$ chmod +x scripts/get_resnet_18_data.sh
$ ./scripts/get_resnet_18_data.sh
$
$ lein run -- --help
$ lein run -- -m models/resnet-18/resnet-18 -i images/kitten.jpg -d images/
$
$ lein uberjar
$ java -jar target/imageclassifier-0.1.0-SNAPSHOT-standalone.jar --help
$ java -jar target/imageclassifier-0.1.0-SNAPSHOT-standalone.jar \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# specific language governing permissions and limitations
# under the License.

set -e
set -evx

MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd)

Expand All @@ -26,19 +26,20 @@ data_path=$MXNET_ROOT/models/resnet-18/
image_path=$MXNET_ROOT/images/

if [ ! -d "$data_path" ]; then
mkdir -p "$data_path"
mkdir -p "$data_path"
fi

if [ ! -d "$image_path" ]; then
mkdir -p "$image_path"
mkdir -p "$image_path"
fi

if [ ! -f "$data_path/resnet-18-0000.params" ]; then
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-symbol.json -P $data_path
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-0000.params -P $data_path
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/synset.txt -P $data_path
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-symbol.json -P $data_path
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-0000.params -P $data_path
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/synset.txt -P $data_path
fi

if [ ! -f "$image_path/kitten.jpg" ]; then
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P $image_path
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P $image_path
wget https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg -P $image_path
fi
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

(def cli-options
[["-m" "--model-path-prefix PREFIX" "Model path prefix"
:default "models/resnet-152/resnet-152"
:default "models/resnet-18/resnet-18"
:validate [#(check-valid-file (str % "-symbol.json"))
"Model path prefix is invalid"]]
["-i" "--input-image IMAGE" "Input image"
Expand Down Expand Up @@ -80,7 +80,9 @@
factory (infer/model-factory model-path-prefix descriptors)
classifier (infer/create-image-classifier
factory {:contexts [(context/default-context)]})]
(println "Classifying a single image")
(print-predictions (classify-single-image classifier input-image))
(println "Classifying images in a directory")
(doseq [predictions (classify-images-in-dir classifier input-dir)]
(print-predictions predictions))))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ Run object detection on images using clojure infer package.
```
$ chmod +x scripts/get_ssd_data.sh
$ ./scripts/get_ssd_data.sh
$
$ lein run -- --help
$ lein run -- -m models/resnet50_ssd/resnet50_ssd_model -i images/dog.jpg -d images/
$
$ lein uberjar
$ java -jar target/objectdetector-0.1.0-SNAPSHOT-standalone.jar --help
$ java -jar target/objectdetector-0.1.0-SNAPSHOT-standalone.jar \
-m models/resnet50_ssd/resnet50_ssd_model -i images/kitten.jpg -d images/
-m models/resnet50_ssd/resnet50_ssd_model -i images/dog.jpg -d images/
```
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
:validate [#(check-valid-file (str % "-symbol.json"))
"Model path prefix is invalid"]]
["-i" "--input-image IMAGE" "Input image"
:default "images/kitten.jpg"
:default "images/dog.jpg"
:validate [check-valid-file "Input file not found"]]
["-d" "--input-dir IMAGE_DIR" "Input directory"
:default "images/"
Expand Down Expand Up @@ -106,7 +106,9 @@
detector (infer/create-object-detector
factory
{:contexts [(context/default-context)]})]
(println "Object detection on a single image")
(print-predictions (detect-single-image detector input-image) width height)
(println "Object detection on images in a directory")
(doseq [predictions (detect-images-in-dir detector input-dir)]
(print-predictions predictions width height))))

Expand Down
4 changes: 4 additions & 0 deletions contrib/clojure-package/examples/infer/predictor/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ Run model prediction using clojure infer package.
```
$ chmod +x scripts/get_resnet_18_data.sh
$ ./scripts/get_resnet_18_data.sh
$
$ lein run -- --help
$ lein run -- -m models/resnet-18/resnet-18 -i images/kitten.jpg
$
$ lein uberjar
$ java -jar target/predictor-0.1.0-SNAPSHOT-standalone.jar --help
$ java -jar target/predictor-0.1.0-SNAPSHOT-standalone.jar \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# specific language governing permissions and limitations
# under the License.

set -e
set -evx

MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

(def cli-options
[["-m" "--model-path-prefix PREFIX" "Model path prefix"
:default "models/resnet-152/resnet-152"
:default "models/resnet-18/resnet-18"
:validate [#(check-valid-file (str % "-symbol.json"))
"Model path prefix is invalid"]]
["-i" "--input-image IMAGE" "Image path"
Expand All @@ -53,12 +53,10 @@
"Preprocesses image to make it ready for prediction"
[image-path width height]
(-> image-path
image/read-image
(image/resize-image width height)
; HWC -> CHW
(ndarray/transpose (shape/->shape [2 0 1]))
(ndarray/expand-dims 0)
(ndarray/as-type dtype/FLOAT32)))
infer/load-image-from-file
(infer/reshape-image width height)
(infer/buffered-image-to-pixels (shape/->shape [3 width height]))
(ndarray/expand-dims 0)))

(defn do-inference
"Run inference using given predictor"
Expand Down Expand Up @@ -90,8 +88,8 @@
predictor (infer/create-predictor
factory
{:contexts [(context/default-context)]})
image (preprocess input-image width height)
predictions (do-inference predictor image)
image-ndarray (preprocess input-image width height)
predictions (do-inference predictor image-ndarray)
best-prediction (postprocess model-path-prefix predictions)]
(print-prediction best-prediction)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

(deftest predictor-test
(let [predictor (create-predictor)
image (preprocess image-file width height)
predictions (do-inference predictor image)
image-ndarray (preprocess image-file width height)
predictions (do-inference predictor image-ndarray)
best-prediction (postprocess model-path-prefix predictions)]
(is (= "n02123159 tiger cat" best-prediction))))
6 changes: 3 additions & 3 deletions contrib/clojure-package/integration-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@

set -evx

MXNET_HOME=${PWD}
MXNET_HOME=$(cd "$(dirname $0)/../.."; pwd)
EXAMPLES_HOME=${MXNET_HOME}/contrib/clojure-package/examples
#cd ${MXNET_HOME}/contrib/clojure-package
#lein test
#lein cloverage --codecov
for i in `find ${EXAMPLES_HOME} -name test` ; do
cd ${i} && lein test
for test_dir in `find ${EXAMPLES_HOME} -name test` ; do
cd ${test_dir} && lein test
done
38 changes: 38 additions & 0 deletions contrib/clojure-package/scripts/infer/get_resnet_18_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -evx

if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME/data"
else
MXNET_ROOT=$(cd "$(dirname $0)/../.."; pwd)
data_path="$MXNET_ROOT/data"
fi

if [ ! -d "$data_path" ]; then
mkdir -p "$data_path"
fi

resnet_18_data_path="$data_path/resnet-18"
if [ ! -f "$resnet_18_data_path/resnet-18-0000.params" ]; then
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-symbol.json -P $resnet_18_data_path
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-0000.params -P $resnet_18_data_path
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/synset.txt -P $resnet_18_data_path
fi
39 changes: 39 additions & 0 deletions contrib/clojure-package/scripts/infer/get_ssd_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


set -evx

if [ ! -z "$MXNET_HOME" ]; then
data_path="$MXNET_HOME/data"
else
MXNET_ROOT=$(cd "$(dirname $0)/../.."; pwd)
data_path="$MXNET_ROOT/data"
fi

if [ ! -d "$data_path" ]; then
mkdir -p "$data_path"
fi

resnet50_ssd_data_path="$data_path/resnet50_ssd"
if [ ! -f "$resnet50_ssd_data_path/resnet50_ssd_model-0000.params" ]; then
wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json -P $resnet50_ssd_data_path
wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params -P $resnet50_ssd_data_path
wget https://s3.amazonaws.com/model-server/models/resnet50_ssd/synset.txt -P $resnet50_ssd_data_path
fi
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.infer.imageclassifier-test
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[clojure.test :refer :all]))

(def model-dir "data/")
(def model-path-prefix (str model-dir "resnet-18/resnet-18"))

(when-not (.exists (io/file (str model-path-prefix "-symbol.json")))
(sh "./scripts/infer/get_resnet_18_data.sh"))

(defn create-classifier []
(let [descriptors [(mx-io/data-desc {:name "data"
:shape [1 3 224 224]
:layout layout/NCHW
:dtype dtype/FLOAT32})]
factory (infer/model-factory model-path-prefix descriptors)]
(infer/create-image-classifier factory)))

(deftest test-single-classification
(let [classifier (create-classifier)
image (infer/load-image-from-file "test/test-images/kitten.jpg")
[predictions] (infer/classify-image classifier image 5)]
(is (some? predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(float? (second %)) predictions))
(is (every? #(< 0 (second %) 1) predictions))
(is (= ["n02123159 tiger cat"
"n02124075 Egyptian cat"
"n02123045 tabby, tabby cat"
"n02127052 lynx, catamount"
"n02128757 snow leopard, ounce, Panthera uncia"]
(map first predictions)))))

(deftest test-batch-classification
(let [classifier (create-classifier)
image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
"test/test-images/Pug-Cookie.jpg"])
batch-predictions (infer/classify-image-batch classifier image-batch 5)
predictions (first batch-predictions)]
(is (some? batch-predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(float? (second %)) predictions))
(is (every? #(< 0 (second %) 1) predictions))))
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.infer.objectdetector-test
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[clojure.test :refer :all]))

(def model-dir "data/")
(def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model"))

(when-not (.exists (io/file (str model-path-prefix "-symbol.json")))
(sh "./scripts/infer/get_ssd_data.sh"))

(defn create-detector []
(let [descriptors [(mx-io/data-desc {:name "data"
:shape [1 3 512 512]
:layout layout/NCHW
:dtype dtype/FLOAT32})]
factory (infer/model-factory model-path-prefix descriptors)]
(infer/create-object-detector factory)))

(deftest test-single-detection
(let [detector (create-detector)
image (infer/load-image-from-file "test/test-images/kitten.jpg")
[predictions] (infer/detect-objects detector image 5)]
(is (some? predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(= 5 (count (second %))) predictions))
(is (every? #(< 0 (first (second %)) 1) predictions))
(is (= "cat" (first (first predictions))))))

(deftest test-batch-detection
(let [detector (create-detector)
image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
"test/test-images/Pug-Cookie.jpg"])
batch-predictions (infer/detect-objects-batch detector image-batch 5)
predictions (first batch-predictions)]
(is (some? batch-predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(= 5 (count (second %))) predictions))
(is (every? #(< 0 (first (second %)) 1) predictions))))
Loading

0 comments on commit a15a098

Please sign in to comment.