Skip to content

Commit

Permalink
Port of scala infer package to clojure (apache#13595)
Browse files Browse the repository at this point in the history
* Port of scala infer package to clojure

* Add inference examples

* Fix project.clj

* Update code for integration tests

* Address comments and add unit tests

* Add specs and simplify interface

* Minor nit

* Update README
  • Loading branch information
kedarbellare authored and haohuw committed Jun 23, 2019
1 parent 45ab188 commit 158c67a
Show file tree
Hide file tree
Showing 32 changed files with 1,556 additions and 6 deletions.
12 changes: 12 additions & 0 deletions contrib/clojure-package/examples/infer/imageclassifier/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/target
/classes
/checkouts
/images
pom.xml
pom.xml.asc
*.jar
*.class
/.lein-*
/.nrepl-port
.hgignore
.hg/
24 changes: 24 additions & 0 deletions contrib/clojure-package/examples/infer/imageclassifier/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# imageclassifier

Run image classification using clojure infer package.

## Installation

Before you run this example, make sure that you have the clojure package installed.
In the main clojure package directory, do `lein install`. Then you can run
`lein install` in this directory.

## Usage

```
$ chmod +x scripts/get_resnet_18_data.sh
$ ./scripts/get_resnet_18_data.sh
$
$ lein run -- --help
$ lein run -- -m models/resnet-18/resnet-18 -i images/kitten.jpg -d images/
$
$ lein uberjar
$ java -jar target/imageclassifier-0.1.0-SNAPSHOT-standalone.jar --help
$ java -jar target/imageclassifier-0.1.0-SNAPSHOT-standalone.jar \
-m models/resnet-18/resnet-18 -i images/kitten.jpg -d images/
```
25 changes: 25 additions & 0 deletions contrib/clojure-package/examples/infer/imageclassifier/project.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(defproject imageclassifier "0.1.0-SNAPSHOT"
:description "Image classification using infer with MXNet"
:plugins [[lein-cljfmt "0.5.7"]]
:dependencies [[org.clojure/clojure "1.9.0"]
[org.clojure/tools.cli "0.4.1"]
[org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]]
:main ^:skip-aot infer.imageclassifier-example
:profiles {:uberjar {:aot :all}})
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -evx

MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd)

data_path=$MXNET_ROOT/models/resnet-18/

image_path=$MXNET_ROOT/images/

if [ ! -d "$data_path" ]; then
mkdir -p "$data_path"
fi

if [ ! -d "$image_path" ]; then
mkdir -p "$image_path"
fi

if [ ! -f "$data_path/resnet-18-0000.params" ]; then
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-symbol.json -P $data_path
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/resnet-18-0000.params -P $data_path
wget https://s3.us-east-2.amazonaws.com/scala-infer-models/resnet-18/synset.txt -P $data_path
fi

if [ ! -f "$image_path/kitten.jpg" ]; then
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P $image_path
wget https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg -P $image_path
fi
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -e

MXNET_ROOT=$(cd "$(dirname $0)/.."; pwd)

data_path=$MXNET_ROOT/models/resnet-152/

image_path=$MXNET_ROOT/images/

if [ ! -d "$data_path" ]; then
mkdir -p "$data_path"
fi

if [ ! -d "$image_path" ]; then
mkdir -p "$image_path"
fi

if [ ! -f "$data_path/resnet-152-0000.params" ]; then
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/resnet-152-0000.params -P $data_path
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/resnet-152-symbol.json -P $data_path
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/synset.txt -P $data_path
fi

if [ ! -f "$image_path/kitten.jpg" ]; then
wget https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/resnet152/kitten.jpg -P $image_path
fi
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
(ns infer.imageclassifier-example
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
[clojure.string :refer [join]]
[clojure.tools.cli :refer [parse-opts]])
(:gen-class))

(defn check-valid-dir
"Check that the input directory exists"
[input-dir]
(let [dir (io/file input-dir)]
(and
(.exists dir)
(.isDirectory dir))))

(defn check-valid-file
"Check that the file exists"
[input-file]
(.exists (io/file input-file)))

(def cli-options
[["-m" "--model-path-prefix PREFIX" "Model path prefix"
:default "models/resnet-18/resnet-18"
:validate [#(check-valid-file (str % "-symbol.json"))
"Model path prefix is invalid"]]
["-i" "--input-image IMAGE" "Input image"
:default "images/kitten.jpg"
:validate [check-valid-file "Input file not found"]]
["-d" "--input-dir IMAGE_DIR" "Input directory"
:default "images/"
:validate [check-valid-dir "Input directory not found"]]
["-h" "--help"]])

(defn print-predictions
"Print image classifier predictions for the given input file"
[predictions]
(println (apply str (repeat 80 "=")))
(doseq [[label probability] predictions]
(println (format "Class: %s Probability=%.8f" label probability)))
(println (apply str (repeat 80 "="))))

(defn classify-single-image
"Classify a single image and print top-5 predictions"
[classifier input-image]
(let [image (infer/load-image-from-file input-image)
topk 5
[predictions] (infer/classify-image classifier image topk)]
predictions))

(defn classify-images-in-dir
"Classify all jpg images in the directory"
[classifier input-dir]
(let [batch-size 20
image-file-batches (->> input-dir
io/file
file-seq
(filter #(.isFile %))
(filter #(re-matches #".*\.jpg$" (.getPath %)))
(mapv #(.getPath %))
(partition-all batch-size))]
(apply
concat
(for [image-files image-file-batches]
(let [image-batch (infer/load-image-paths image-files)
topk 5]
(infer/classify-image-batch classifier image-batch topk))))))

(defn run-classifier
"Runs an image classifier based on options provided"
[options]
(let [{:keys [model-path-prefix input-image input-dir]} options
descriptors [{:name "data"
:shape [1 3 224 224]
:layout layout/NCHW
:dtype dtype/FLOAT32}]
factory (infer/model-factory model-path-prefix descriptors)
classifier (infer/create-image-classifier
factory {:contexts [(context/default-context)]})]
(println "Classifying a single image")
(print-predictions (classify-single-image classifier input-image))
(println "Classifying images in a directory")
(doseq [predictions (classify-images-in-dir classifier input-dir)]
(print-predictions predictions))))

(defn -main
[& args]
(let [{:keys [options summary errors] :as opts}
(parse-opts args cli-options)]
(cond
(:help options) (println summary)
(some? errors) (println (join "\n" errors))
:else (run-classifier options))))
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns infer.imageclassifier-example-test
(:require [infer.imageclassifier-example :refer [classify-single-image
classify-images-in-dir]]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[clojure.test :refer :all]))

(def model-dir "models/")
(def image-dir "images/")
(def model-path-prefix (str model-dir "resnet-18/resnet-18"))
(def image-file (str image-dir "kitten.jpg"))

(when-not (.exists (io/file (str model-path-prefix "-symbol.json")))
(sh "./scripts/get_resnet_18_data.sh"))

(defn create-classifier []
(let [descriptors [{:name "data"
:shape [1 3 224 224]
:layout layout/NCHW
:dtype dtype/FLOAT32}]
factory (infer/model-factory model-path-prefix descriptors)]
(infer/create-image-classifier factory)))

(deftest test-single-classification
(let [classifier (create-classifier)
predictions (classify-single-image classifier image-file)]
(is (some? predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(float? (second %)) predictions))
(is (every? #(< 0 (second %) 1) predictions))
(is (= ["n02123159 tiger cat"
"n02124075 Egyptian cat"
"n02123045 tabby, tabby cat"
"n02127052 lynx, catamount"
"n02128757 snow leopard, ounce, Panthera uncia"]
(map first predictions)))))

(deftest test-batch-classification
(let [classifier (create-classifier)
batch-predictions (classify-images-in-dir classifier image-dir)
predictions (first batch-predictions)]
(is (some? batch-predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(float? (second %)) predictions))
(is (every? #(< 0 (second %) 1) predictions))))
12 changes: 12 additions & 0 deletions contrib/clojure-package/examples/infer/objectdetector/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/target
/classes
/checkouts
/images
pom.xml
pom.xml.asc
*.jar
*.class
/.lein-*
/.nrepl-port
.hgignore
.hg/
24 changes: 24 additions & 0 deletions contrib/clojure-package/examples/infer/objectdetector/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# objectdetector

Run object detection on images using clojure infer package.

## Installation

Before you run this example, make sure that you have the clojure package installed.
In the main clojure package directory, do `lein install`. Then you can run
`lein install` in this directory.

## Usage

```
$ chmod +x scripts/get_ssd_data.sh
$ ./scripts/get_ssd_data.sh
$
$ lein run -- --help
$ lein run -- -m models/resnet50_ssd/resnet50_ssd_model -i images/dog.jpg -d images/
$
$ lein uberjar
$ java -jar target/objectdetector-0.1.0-SNAPSHOT-standalone.jar --help
$ java -jar target/objectdetector-0.1.0-SNAPSHOT-standalone.jar \
-m models/resnet50_ssd/resnet50_ssd_model -i images/dog.jpg -d images/
```
25 changes: 25 additions & 0 deletions contrib/clojure-package/examples/infer/objectdetector/project.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(defproject objectdetector "0.1.0-SNAPSHOT"
:description "Object detection using infer with MXNet"
:plugins [[lein-cljfmt "0.5.7"]]
:dependencies [[org.clojure/clojure "1.9.0"]
[org.clojure/tools.cli "0.4.1"]
[org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]]
:main ^:skip-aot infer.objectdetector-example
:profiles {:uberjar {:aot :all}})
Loading

0 comments on commit 158c67a

Please sign in to comment.