Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Update code for integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Kedar Bellare committed Dec 22, 2018
1 parent 341fe62 commit fb7ee7f
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
(defn check-valid-file
"Check that the file exists"
[input-file]
(let [file (io/file input-file)]
(.exists file)))
(.exists (io/file input-file)))

(def cli-options
[["-m" "--model-path-prefix PREFIX" "Model path prefix"
Expand All @@ -34,18 +33,12 @@
["-d" "--input-dir IMAGE_DIR" "Input directory"
:default "images/"
:validate [check-valid-dir "Input directory not found"]]
[nil "--device [cpu|gpu]" "Device"
:default "cpu"
:validate [#(#{"cpu" "gpu"} %) "Device must be one of cpu or gpu"]]
[nil "--device-id INT" "Device ID"
:default 0]
["-h" "--help"]])

(defn print-predictions
"Print image classifier predictions for the given input file"
[input-file predictions]
[predictions]
(println (apply str (repeat 80 "=")))
(println "Input file:" input-file)
(doseq [[label probability] predictions]
(println (format "Class: %s Probability=%.8f" label probability)))
(println (apply str (repeat 80 "="))))
Expand All @@ -56,7 +49,7 @@
(let [image (infer/load-image-from-file input-image)
topk 5
[predictions] (infer/classify-image classifier image topk)]
(print-predictions input-image predictions)))
predictions))

(defn classify-images-in-dir
"Classify all jpg images in the directory"
Expand All @@ -69,32 +62,27 @@
(filter #(re-matches #".*\.jpg$" (.getPath %)))
(mapv #(.getPath %))
(partition-all batch-size))]
(doseq [image-files image-file-batches]
(let [image-batch (infer/load-image-paths image-files)
topk 5]
(doseq [[input-image preds]
(map list
image-files
(infer/classify-image-batch classifier image-batch topk))]
(print-predictions input-image preds))))))
(apply
concat
(for [image-files image-file-batches]
(let [image-batch (infer/load-image-paths image-files)
topk 5]
(infer/classify-image-batch classifier image-batch topk))))))

(defn run-classifier
"Runs an image classifier based on options provided"
[options]
(let [{:keys [model-path-prefix input-image input-dir
device device-id]} options
ctx (if (= device "cpu")
(context/cpu device-id)
(context/gpu device-id))
(let [{:keys [model-path-prefix input-image input-dir]} options
descriptors [(mx-io/data-desc {:name "data"
:shape [1 3 224 224]
:layout layout/NCHW
:dtype dtype/FLOAT32})]
factory (infer/model-factory model-path-prefix descriptors)
classifier (infer/create-image-classifier
factory {:contexts [ctx]})]
(classify-single-image classifier input-image)
(classify-images-in-dir classifier input-dir)))
factory {:contexts [(context/default-context)]})]
(print-predictions (classify-single-image classifier input-image))
(doseq [predictions (classify-images-in-dir classifier input-dir)]
(print-predictions predictions))))

(defn -main
[& args]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns infer.imageclassifier-example-test
(:require [infer.imageclassifier-example :refer [classify-single-image
classify-images-in-dir]]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[clojure.test :refer :all]))

(def model-dir "models/")
(def image-dir "images/")
(def model-path-prefix (str model-dir "resnet-18/resnet-18"))
(def image-file (str image-dir "kitten.jpg"))

(when-not (.exists (io/file (str model-path-prefix "-symbol.json")))
(sh "./scripts/get_resnet_18_data.sh"))

(defn create-classifier []
(let [descriptors [(mx-io/data-desc {:name "data"
:shape [1 3 224 224]
:layout layout/NCHW
:dtype dtype/FLOAT32})]
factory (infer/model-factory model-path-prefix descriptors)]
(infer/create-image-classifier factory)))

(deftest test-single-classification
(let [classifier (create-classifier)
predictions (classify-single-image classifier image-file)]
(is (some? predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(float? (second %)) predictions))
(is (every? #(< 0 (second %) 1) predictions))
(is (= ["n02123159 tiger cat"
"n02124075 Egyptian cat"
"n02123045 tabby, tabby cat"
"n02127052 lynx, catamount"
"n02128757 snow leopard, ounce, Panthera uncia"]
(map first predictions)))))

(deftest test-batch-classification
(let [classifier (create-classifier)
batch-predictions (classify-images-in-dir classifier image-dir)
predictions (first batch-predictions)]
(is (some? batch-predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(float? (second %)) predictions))
(is (every? #(< 0 (second %) 1) predictions))))
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns infer.objectdetector-example
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
Expand All @@ -20,8 +36,7 @@
(defn check-valid-file
"Check that the file exists"
[input-file]
(let [file (io/file input-file)]
(.exists file)))
(.exists (io/file input-file)))

(def cli-options
[["-m" "--model-path-prefix PREFIX" "Model path prefix"
Expand All @@ -34,18 +49,12 @@
["-d" "--input-dir IMAGE_DIR" "Input directory"
:default "images/"
:validate [check-valid-dir "Input directory not found"]]
[nil "--device [cpu|gpu]" "Device"
:default "cpu"
:validate [#(#{"cpu" "gpu"} %) "Device must be one of cpu or gpu"]]
[nil "--device-id INT" "Device ID"
:default 0]
["-h" "--help"]])

(defn print-predictions
"Print image detector predictions for the given input file"
[input-file predictions width height]
[predictions width height]
(println (apply str (repeat 80 "=")))
(println "Top detected objects for input file:" input-file)
(doseq [[label prob-and-bounds] predictions]
(println (format
"Class: %s Prob=%.5f Coords=(%.3f, %.3f, %.3f, %.3f)"
Expand All @@ -59,15 +68,15 @@

(defn detect-single-image
"Detect objects in a single image and print top-5 predictions"
[detector input-image width height]
[detector input-image]
(let [image (infer/load-image-from-file input-image)
topk 5
[predictions] (infer/detect-objects detector image topk)]
(print-predictions input-image predictions width height)))
predictions))

(defn detect-images-in-dir
"Detect objects in all jpg images in the directory"
[detector input-dir width height]
[detector input-dir]
(let [batch-size 20
image-file-batches (->> input-dir
io/file
Expand All @@ -76,32 +85,30 @@
(filter #(re-matches #".*\.jpg$" (.getPath %)))
(mapv #(.getPath %))
(partition-all batch-size))]
(doseq [image-files image-file-batches]
(let [image-batch (infer/load-image-paths image-files)
topk 5]
(doseq [[input-image preds]
(map list
image-files
(infer/detect-objects-batch detector image-batch topk))]
(print-predictions input-image preds width height))))))
(apply
concat
(for [image-files image-file-batches]
(let [image-batch (infer/load-image-paths image-files)
topk 5]
(infer/detect-objects-batch detector image-batch topk))))))

(defn run-detector
"Runs an image detector based on options provided"
[options]
(let [{:keys [model-path-prefix input-image input-dir
device device-id]} options
ctx (if (= device "cpu")
(context/cpu device-id)
(context/gpu device-id))
width 512 height 512
descriptors [(mx-io/data-desc {:name "data"
:shape [1 3 height width]
:layout layout/NCHW
:dtype dtype/FLOAT32})]
factory (infer/model-factory model-path-prefix descriptors)
detector (infer/create-object-detector factory {:contexts [ctx]})]
(detect-single-image detector input-image width height)
(detect-images-in-dir detector input-dir width height)))
detector (infer/create-object-detector
factory
{:contexts [(context/default-context)]})]
(print-predictions (detect-single-image detector input-image) width height)
(doseq [predictions (detect-images-in-dir detector input-dir)]
(print-predictions predictions width height))))

(defn -main
[& args]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns infer.objectdetector-example-test
(:require [infer.objectdetector-example :refer [detect-single-image
detect-images-in-dir]]
[org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.infer :as infer]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.layout :as layout]
[clojure.java.io :as io]
[clojure.java.shell :refer [sh]]
[clojure.test :refer :all]))

(def model-dir "models/")
(def image-dir "images/")
(def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model"))
(def image-file (str image-dir "dog.jpg"))

(when-not (.exists (io/file (str model-path-prefix "-symbol.json")))
(sh "./scripts/get_ssd_data.sh"))

(defn create-detector []
(let [descriptors [(mx-io/data-desc {:name "data"
:shape [1 3 512 512]
:layout layout/NCHW
:dtype dtype/FLOAT32})]
factory (infer/model-factory model-path-prefix descriptors)]
(infer/create-object-detector factory)))

(deftest test-single-detection
(let [detector (create-detector)
predictions (detect-single-image detector image-file)]
(is (some? predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(= 5 (count (second %))) predictions))
(is (every? #(< 0 (first (second %)) 1) predictions))
(is (= ["car" "bicycle" "dog" "bicycle" "person"]
(map first predictions)))))

(deftest test-batch-detection
(let [detector (create-detector)
batch-predictions (detect-images-in-dir detector image-dir)
predictions (first batch-predictions)]
(is (some? batch-predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(= 5 (count (second %))) predictions))
(is (every? #(< 0 (first (second %)) 1) predictions))))
Loading

0 comments on commit fb7ee7f

Please sign in to comment.