forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port of scala infer package to clojure (apache#13595)
* Port of scala infer package to clojure * Add inference examples * Fix project.clj * Update code for integration tests * Address comments and add unit tests * Add specs and simplify interface * Minor nit * Update README
- Loading branch information
1 parent
45ab188
commit 158c67a
Showing
32 changed files
with
1,556 additions
and
6 deletions.
There are no files selected for viewing
12 changes: 12 additions & 0 deletions
12
contrib/clojure-package/examples/infer/imageclassifier/.gitignore
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
/target | ||
/classes | ||
/checkouts | ||
/images | ||
pom.xml | ||
pom.xml.asc | ||
*.jar | ||
*.class | ||
/.lein-* | ||
/.nrepl-port | ||
.hgignore | ||
.hg/ |
24 changes: 24 additions & 0 deletions
24
contrib/clojure-package/examples/infer/imageclassifier/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# imageclassifier | ||
|
||
Run image classification using clojure infer package. | ||
|
||
## Installation | ||
|
||
Before you run this example, make sure that you have the clojure package installed. | ||
In the main clojure package directory, do `lein install`. Then you can run | ||
`lein install` in this directory. | ||
|
||
## Usage | ||
|
||
``` | ||
$ chmod +x scripts/get_resnet_18_data.sh | ||
$ ./scripts/get_resnet_18_data.sh | ||
$ | ||
$ lein run -- --help | ||
$ lein run -- -m models/resnet-18/resnet-18 -i images/kitten.jpg -d images/ | ||
$ | ||
$ lein uberjar | ||
$ java -jar target/imageclassifier-0.1.0-SNAPSHOT-standalone.jar --help | ||
$ java -jar target/imageclassifier-0.1.0-SNAPSHOT-standalone.jar \ | ||
-m models/resnet-18/resnet-18 -i images/kitten.jpg -d images/ | ||
``` |
25 changes: 25 additions & 0 deletions
25
contrib/clojure-package/examples/infer/imageclassifier/project.clj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
;; | ||
;; Licensed to the Apache Software Foundation (ASF) under one or more | ||
;; contributor license agreements. See the NOTICE file distributed with | ||
;; this work for additional information regarding copyright ownership. | ||
;; The ASF licenses this file to You under the Apache License, Version 2.0 | ||
;; (the "License"); you may not use this file except in compliance with | ||
;; the License. You may obtain a copy of the License at | ||
;; | ||
;; http://www.apache.org/licenses/LICENSE-2.0 | ||
;; | ||
;; Unless required by applicable law or agreed to in writing, software | ||
;; distributed under the License is distributed on an "AS IS" BASIS, | ||
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
;; See the License for the specific language governing permissions and | ||
;; limitations under the License. | ||
;; | ||
|
||
(defproject imageclassifier "0.1.0-SNAPSHOT" | ||
:description "Image classification using infer with MXNet" | ||
:plugins [[lein-cljfmt "0.5.7"]] | ||
:dependencies [[org.clojure/clojure "1.9.0"] | ||
[org.clojure/tools.cli "0.4.1"] | ||
[org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]] | ||
:main ^:skip-aot infer.imageclassifier-example | ||
:profiles {:uberjar {:aot :all}}) |
45 changes: 45 additions & 0 deletions
45
contrib/clojure-package/examples/infer/imageclassifier/scripts/get_resnet_18_data.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#!/bin/bash | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
set -evx | ||
|
||
MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd) | ||
|
||
data_path=$MXNET_ROOT/models/resnet-18/ | ||
|
||
image_path=$MXNET_ROOT/images/ | ||
|
||
if [ ! -d "$data_path" ]; then | ||
mkdir -p "$data_path" | ||
fi | ||
|
||
if [ ! -d "$image_path" ]; then | ||
mkdir -p "$image_path" | ||
fi | ||
|
||
if [ ! -f "$data_path/resnet-18-0000.params" ]; then | ||
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-symbol.json -P $data_path | ||
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-0000.params -P $data_path | ||
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/synset.txt -P $data_path | ||
fi | ||
|
||
if [ ! -f "$image_path/kitten.jpg" ]; then | ||
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P $image_path | ||
wget https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg -P $image_path | ||
fi |
44 changes: 44 additions & 0 deletions
44
contrib/clojure-package/examples/infer/imageclassifier/scripts/get_resnet_data.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#!/bin/bash | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
set -e | ||
|
||
MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd) | ||
|
||
data_path=$MXNET_ROOT/models/resnet-152/ | ||
|
||
image_path=$MXNET_ROOT/images/ | ||
|
||
if [ ! -d "$data_path" ]; then | ||
mkdir -p "$data_path" | ||
fi | ||
|
||
if [ ! -d "$image_path" ]; then | ||
mkdir -p "$image_path" | ||
fi | ||
|
||
if [ ! -f "$data_path/resnet-152-0000.params" ]; then | ||
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/resnet-152-0000.params -P $data_path | ||
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/resnet-152-symbol.json -P $data_path | ||
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/synset.txt -P $data_path | ||
fi | ||
|
||
if [ ! -f "$image_path/kitten.jpg" ]; then | ||
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P $image_path | ||
fi |
95 changes: 95 additions & 0 deletions
95
contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
(ns infer.imageclassifier-example | ||
(:require [org.apache.clojure-mxnet.context :as context] | ||
[org.apache.clojure-mxnet.dtype :as dtype] | ||
[org.apache.clojure-mxnet.infer :as infer] | ||
[org.apache.clojure-mxnet.layout :as layout] | ||
[clojure.java.io :as io] | ||
[clojure.string :refer [join]] | ||
[clojure.tools.cli :refer [parse-opts]]) | ||
(:gen-class)) | ||
|
||
(defn check-valid-dir | ||
"Check that the input directory exists" | ||
[input-dir] | ||
(let [dir (io/file input-dir)] | ||
(and | ||
(.exists dir) | ||
(.isDirectory dir)))) | ||
|
||
(defn check-valid-file | ||
"Check that the file exists" | ||
[input-file] | ||
(.exists (io/file input-file))) | ||
|
||
(def cli-options | ||
[["-m" "--model-path-prefix PREFIX" "Model path prefix" | ||
:default "models/resnet-18/resnet-18" | ||
:validate [#(check-valid-file (str % "-symbol.json")) | ||
"Model path prefix is invalid"]] | ||
["-i" "--input-image IMAGE" "Input image" | ||
:default "images/kitten.jpg" | ||
:validate [check-valid-file "Input file not found"]] | ||
["-d" "--input-dir IMAGE_DIR" "Input directory" | ||
:default "images/" | ||
:validate [check-valid-dir "Input directory not found"]] | ||
["-h" "--help"]]) | ||
|
||
(defn print-predictions | ||
"Print image classifier predictions for the given input file" | ||
[predictions] | ||
(println (apply str (repeat 80 "="))) | ||
(doseq [[label probability] predictions] | ||
(println (format "Class: %s Probability=%.8f" label probability))) | ||
(println (apply str (repeat 80 "=")))) | ||
|
||
(defn classify-single-image | ||
"Classify a single image and print top-5 predictions" | ||
[classifier input-image] | ||
(let [image (infer/load-image-from-file input-image) | ||
topk 5 | ||
[predictions] (infer/classify-image classifier image topk)] | ||
predictions)) | ||
|
||
(defn classify-images-in-dir | ||
"Classify all jpg images in the directory" | ||
[classifier input-dir] | ||
(let [batch-size 20 | ||
image-file-batches (->> input-dir | ||
io/file | ||
file-seq | ||
(filter #(.isFile %)) | ||
(filter #(re-matches #".*\.jpg$" (.getPath %))) | ||
(mapv #(.getPath %)) | ||
(partition-all batch-size))] | ||
(apply | ||
concat | ||
(for [image-files image-file-batches] | ||
(let [image-batch (infer/load-image-paths image-files) | ||
topk 5] | ||
(infer/classify-image-batch classifier image-batch topk)))))) | ||
|
||
(defn run-classifier | ||
"Runs an image classifier based on options provided" | ||
[options] | ||
(let [{:keys [model-path-prefix input-image input-dir]} options | ||
descriptors [{:name "data" | ||
:shape [1 3 224 224] | ||
:layout layout/NCHW | ||
:dtype dtype/FLOAT32}] | ||
factory (infer/model-factory model-path-prefix descriptors) | ||
classifier (infer/create-image-classifier | ||
factory {:contexts [(context/default-context)]})] | ||
(println "Classifying a single image") | ||
(print-predictions (classify-single-image classifier input-image)) | ||
(println "Classifying images in a directory") | ||
(doseq [predictions (classify-images-in-dir classifier input-dir)] | ||
(print-predictions predictions)))) | ||
|
||
(defn -main | ||
[& args] | ||
(let [{:keys [options summary errors] :as opts} | ||
(parse-opts args cli-options)] | ||
(cond | ||
(:help options) (println summary) | ||
(some? errors) (println (join "\n" errors)) | ||
:else (run-classifier options)))) |
69 changes: 69 additions & 0 deletions
69
...lojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
;; Licensed to the Apache Software Foundation (ASF) under one or more | ||
;; contributor license agreements. See the NOTICE file distributed with | ||
;; this work for additional information regarding copyright ownership. | ||
;; The ASF licenses this file to You under the Apache License, Version 2.0 | ||
;; (the "License"); you may not use this file except in compliance with | ||
;; the License. You may obtain a copy of the License at | ||
;; | ||
;; http://www.apache.org/licenses/LICENSE-2.0 | ||
;; | ||
;; Unless required by applicable law or agreed to in writing, software | ||
;; distributed under the License is distributed on an "AS IS" BASIS, | ||
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
;; See the License for the specific language governing permissions and | ||
;; limitations under the License. | ||
;; | ||
|
||
(ns infer.imageclassifier-example-test | ||
(:require [infer.imageclassifier-example :refer [classify-single-image | ||
classify-images-in-dir]] | ||
[org.apache.clojure-mxnet.context :as context] | ||
[org.apache.clojure-mxnet.dtype :as dtype] | ||
[org.apache.clojure-mxnet.infer :as infer] | ||
[org.apache.clojure-mxnet.layout :as layout] | ||
[clojure.java.io :as io] | ||
[clojure.java.shell :refer [sh]] | ||
[clojure.test :refer :all])) | ||
|
||
(def model-dir "models/") | ||
(def image-dir "images/") | ||
(def model-path-prefix (str model-dir "resnet-18/resnet-18")) | ||
(def image-file (str image-dir "kitten.jpg")) | ||
|
||
(when-not (.exists (io/file (str model-path-prefix "-symbol.json"))) | ||
(sh "./scripts/get_resnet_18_data.sh")) | ||
|
||
(defn create-classifier [] | ||
(let [descriptors [{:name "data" | ||
:shape [1 3 224 224] | ||
:layout layout/NCHW | ||
:dtype dtype/FLOAT32}] | ||
factory (infer/model-factory model-path-prefix descriptors)] | ||
(infer/create-image-classifier factory))) | ||
|
||
(deftest test-single-classification | ||
(let [classifier (create-classifier) | ||
predictions (classify-single-image classifier image-file)] | ||
(is (some? predictions)) | ||
(is (= 5 (count predictions))) | ||
(is (every? #(= 2 (count %)) predictions)) | ||
(is (every? #(string? (first %)) predictions)) | ||
(is (every? #(float? (second %)) predictions)) | ||
(is (every? #(< 0 (second %) 1) predictions)) | ||
(is (= ["n02123159 tiger cat" | ||
"n02124075 Egyptian cat" | ||
"n02123045 tabby, tabby cat" | ||
"n02127052 lynx, catamount" | ||
"n02128757 snow leopard, ounce, Panthera uncia"] | ||
(map first predictions))))) | ||
|
||
(deftest test-batch-classification | ||
(let [classifier (create-classifier) | ||
batch-predictions (classify-images-in-dir classifier image-dir) | ||
predictions (first batch-predictions)] | ||
(is (some? batch-predictions)) | ||
(is (= 5 (count predictions))) | ||
(is (every? #(= 2 (count %)) predictions)) | ||
(is (every? #(string? (first %)) predictions)) | ||
(is (every? #(float? (second %)) predictions)) | ||
(is (every? #(< 0 (second %) 1) predictions)))) |
12 changes: 12 additions & 0 deletions
12
contrib/clojure-package/examples/infer/objectdetector/.gitignore
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
/target | ||
/classes | ||
/checkouts | ||
/images | ||
pom.xml | ||
pom.xml.asc | ||
*.jar | ||
*.class | ||
/.lein-* | ||
/.nrepl-port | ||
.hgignore | ||
.hg/ |
24 changes: 24 additions & 0 deletions
24
contrib/clojure-package/examples/infer/objectdetector/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# objectdetector | ||
|
||
Run object detection on images using clojure infer package. | ||
|
||
## Installation | ||
|
||
Before you run this example, make sure that you have the clojure package installed. | ||
In the main clojure package directory, do `lein install`. Then you can run | ||
`lein install` in this directory. | ||
|
||
## Usage | ||
|
||
``` | ||
$ chmod +x scripts/get_ssd_data.sh | ||
$ ./scripts/get_ssd_data.sh | ||
$ | ||
$ lein run -- --help | ||
$ lein run -- -m models/resnet50_ssd/resnet50_ssd_model -i images/dog.jpg -d images/ | ||
$ | ||
$ lein uberjar | ||
$ java -jar target/objectdetector-0.1.0-SNAPSHOT-standalone.jar --help | ||
$ java -jar target/objectdetector-0.1.0-SNAPSHOT-standalone.jar \ | ||
-m models/resnet50_ssd/resnet50_ssd_model -i images/dog.jpg -d images/ | ||
``` |
25 changes: 25 additions & 0 deletions
25
contrib/clojure-package/examples/infer/objectdetector/project.clj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
;; | ||
;; Licensed to the Apache Software Foundation (ASF) under one or more | ||
;; contributor license agreements. See the NOTICE file distributed with | ||
;; this work for additional information regarding copyright ownership. | ||
;; The ASF licenses this file to You under the Apache License, Version 2.0 | ||
;; (the "License"); you may not use this file except in compliance with | ||
;; the License. You may obtain a copy of the License at | ||
;; | ||
;; http://www.apache.org/licenses/LICENSE-2.0 | ||
;; | ||
;; Unless required by applicable law or agreed to in writing, software | ||
;; distributed under the License is distributed on an "AS IS" BASIS, | ||
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
;; See the License for the specific language governing permissions and | ||
;; limitations under the License. | ||
;; | ||
|
||
(defproject objectdetector "0.1.0-SNAPSHOT" | ||
:description "Object detection using infer with MXNet" | ||
:plugins [[lein-cljfmt "0.5.7"]] | ||
:dependencies [[org.clojure/clojure "1.9.0"] | ||
[org.clojure/tools.cli "0.4.1"] | ||
[org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]] | ||
:main ^:skip-aot infer.objectdetector-example | ||
:profiles {:uberjar {:aot :all}}) |
Oops, something went wrong.