From c2110ada6d43f10710d181c0deb0673fe6d829b2 Mon Sep 17 00:00:00 2001 From: Carin Meier Date: Sat, 12 Jan 2019 19:01:17 -0500 Subject: [PATCH] [Clojure] package infer tweaks (#13864) * change object detection prediction to be a map * change predictions to a map for image-classifiers * change return types of the classifiers to be a map - add tests for base classifier and with-ndarray as well * tweak return types and inputs for predict - add test for plain predict * updated infer-classify examples * adjust the infer/object detections tests * tweak predictor test * Feedback from @kedarbellare review * put scaling back in * put back predict so it can handle multiple inputs * restore original functions signatures (remove first) --- .../src/infer/imageclassifier_example.clj | 19 ++- .../infer/imageclassifier_example_test.clj | 25 +--- .../src/infer/objectdetector_example.clj | 25 ++-- .../infer/objectdetector_example_test.clj | 24 +-- .../predictor/src/infer/predictor_example.clj | 4 +- .../src/org/apache/clojure_mxnet/infer.clj | 137 +++++++++++------- .../infer/imageclassifier_test.clj | 96 +++++++++--- .../infer/objectdetector_test.clj | 47 ++++-- .../clojure_mxnet/infer/predictor_test.clj | 24 ++- 9 files changed, 250 insertions(+), 151 deletions(-) diff --git a/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj index 4ec7ff7f1490..6994b4fadc26 100644 --- a/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj +++ b/contrib/clojure-package/examples/infer/imageclassifier/src/infer/imageclassifier_example.clj @@ -55,8 +55,8 @@ "Print image classifier predictions for the given input file" [predictions] (println (apply str (repeat 80 "="))) - (doseq [[label probability] predictions] - (println (format "Class: %s Probability=%.8f" label probability))) + (doseq [p predictions] + (println p)) (println (apply str (repeat 80 "=")))) (defn classify-single-image @@ -64,8 +64,8 @@ [classifier input-image] (let [image (infer/load-image-from-file input-image) topk 5 - [predictions] (infer/classify-image classifier image topk)] - predictions)) + predictions (infer/classify-image classifier image topk)] + [predictions])) (defn classify-images-in-dir "Classify all jpg images in the directory" @@ -78,12 +78,10 @@ (filter #(re-matches #".*\.jpg$" (.getPath %))) (mapv #(.getPath %)) (partition-all batch-size))] - (apply - concat - (for [image-files image-file-batches] - (let [image-batch (infer/load-image-paths image-files) - topk 5] - (infer/classify-image-batch classifier image-batch topk)))))) + (apply concat (for [image-files image-file-batches] + (let [image-batch (infer/load-image-paths image-files) + topk 5] + (infer/classify-image-batch classifier image-batch topk)))))) (defn run-classifier "Runs an image classifier based on options provided" @@ -98,6 +96,7 @@ factory {:contexts [(context/default-context)]})] (println "Classifying a single image") (print-predictions (classify-single-image classifier input-image)) + (println "\n") (println "Classifying images in a directory") (doseq [predictions (classify-images-in-dir classifier input-dir)] (print-predictions predictions)))) diff --git a/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj index 5b3e08d134f8..4b71f845dd5f 100644 --- a/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj +++ b/contrib/clojure-package/examples/infer/imageclassifier/test/infer/imageclassifier_example_test.clj @@ -43,27 +43,16 @@ (deftest test-single-classification (let [classifier (create-classifier) - predictions (classify-single-image classifier image-file)] + [[predictions]] (classify-single-image classifier image-file)] (is (some? predictions)) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(float? (second %)) predictions)) - (is (every? #(< 0 (second %) 1) predictions)) - (is (= ["n02123159 tiger cat" - "n02124075 Egyptian cat" - "n02123045 tabby, tabby cat" - "n02127052 lynx, catamount" - "n02128757 snow leopard, ounce, Panthera uncia"] - (map first predictions))))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1))))) (deftest test-batch-classification (let [classifier (create-classifier) - batch-predictions (classify-images-in-dir classifier image-dir) - predictions (first batch-predictions)] - (is (some? batch-predictions)) + predictions (first (classify-images-in-dir classifier image-dir))] + (is (some? predictions)) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(float? (second %)) predictions)) - (is (every? #(< 0 (second %) 1) predictions)))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1))))) diff --git a/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj index 53172f0c8cad..5c30e5db63fe 100644 --- a/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj +++ b/contrib/clojure-package/examples/infer/objectdetector/src/infer/objectdetector_example.clj @@ -54,15 +54,15 @@ "Print image detector predictions for the given input file" [predictions width height] (println (apply str (repeat 80 "="))) - (doseq [[label prob-and-bounds] predictions] + (doseq [{:keys [class prob x-min y-min x-max y-max]} predictions] (println (format "Class: %s Prob=%.5f Coords=(%.3f, %.3f, %.3f, %.3f)" - label - (aget prob-and-bounds 0) - (* (aget prob-and-bounds 1) width) - (* (aget prob-and-bounds 2) height) - (* (aget prob-and-bounds 3) width) - (* (aget prob-and-bounds 4) height)))) + class + prob + (* x-min width) + (* y-min height) + (* x-max width) + (* y-max height)))) (println (apply str (repeat 80 "=")))) (defn detect-single-image @@ -84,12 +84,10 @@ (filter #(re-matches #".*\.jpg$" (.getPath %))) (mapv #(.getPath %)) (partition-all batch-size))] - (apply - concat - (for [image-files image-file-batches] - (let [image-batch (infer/load-image-paths image-files) - topk 5] - (infer/detect-objects-batch detector image-batch topk)))))) + (apply concat (for [image-files image-file-batches] + (let [image-batch (infer/load-image-paths image-files) + topk 5] + (infer/detect-objects-batch detector image-batch topk)))))) (defn run-detector "Runs an image detector based on options provided" @@ -107,6 +105,7 @@ {:contexts [(context/default-context)]})] (println "Object detection on a single image") (print-predictions (detect-single-image detector input-image) width height) + (println "\n") (println "Object detection on images in a directory") (doseq [predictions (detect-images-in-dir detector input-dir)] (print-predictions predictions width height)))) diff --git a/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj index 90ed02f67a73..2b8ad951ae22 100644 --- a/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj +++ b/contrib/clojure-package/examples/infer/objectdetector/test/infer/objectdetector_example_test.clj @@ -43,23 +43,23 @@ (deftest test-single-detection (let [detector (create-detector) - predictions (detect-single-image detector image-file)] + predictions (detect-single-image detector image-file) + {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)] (is (some? predictions)) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(= 5 (count (second %))) predictions)) - (is (every? #(< 0 (first (second %)) 1) predictions)) - (is (= ["car" "bicycle" "dog" "bicycle" "person"] - (map first predictions))))) + (is (string? class)) + (is (< 0.8 prob)) + (is (every? #(< 0 % 1) [x-min x-max y-min y-max])) + (is (= #{"dog" "person" "bicycle" "car"} (set (mapv :class predictions)))))) (deftest test-batch-detection (let [detector (create-detector) batch-predictions (detect-images-in-dir detector image-dir) - predictions (first batch-predictions)] + predictions (first batch-predictions) + {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)] (is (some? batch-predictions)) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(= 5 (count (second %))) predictions)) - (is (every? #(< 0 (first (second %)) 1) predictions)))) + (is (string? class)) + (is (< 0.8 prob)) + (every? #(< 0 % 1) [x-min x-max y-min y-max]) + (is (= #{"dog" "person" "bicycle" "car"} (set (mapv :class predictions)))))) diff --git a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj index 498964128dd8..05eb0add3138 100644 --- a/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj +++ b/contrib/clojure-package/examples/infer/predictor/src/infer/predictor_example.clj @@ -59,8 +59,8 @@ (defn do-inference "Run inference using given predictor" [predictor image] - (let [[predictions] (infer/predict-with-ndarray predictor [image])] - predictions)) + (let [predictions (infer/predict-with-ndarray predictor [image])] + (first predictions))) (defn postprocess [model-path-prefix predictions] diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj index 224a39275dac..09edf15b4288 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj @@ -22,7 +22,8 @@ [org.apache.clojure-mxnet.io :as mx-io] [org.apache.clojure-mxnet.shape :as shape] [org.apache.clojure-mxnet.util :as util] - [clojure.spec.alpha :as s]) + [clojure.spec.alpha :as s] + [org.apache.clojure-mxnet.shape :as mx-shape]) (:import (java.awt.image BufferedImage) (org.apache.mxnet NDArray) (org.apache.mxnet.infer Classifier ImageClassifier @@ -39,15 +40,26 @@ (defrecord WrappedObjectDetector [object-detector]) (s/def ::ndarray #(instance? NDArray %)) -(s/def ::float-array (s/and #(.isArray (class %)) #(every? float? %))) -(s/def ::vec-of-float-arrays (s/coll-of ::float-array :kind vector?)) +(s/def ::number-array (s/coll-of number? :kind vector?)) +(s/def ::vvec-of-numbers (s/coll-of ::number-array :kind vector?)) (s/def ::vec-of-ndarrays (s/coll-of ::ndarray :kind vector?)) +(s/def ::image #(instance? BufferedImage %)) +(s/def ::batch-images (s/coll-of ::image :kind vector?)) (s/def ::wrapped-predictor (s/keys :req-un [::predictor])) (s/def ::wrapped-classifier (s/keys :req-un [::classifier])) (s/def ::wrapped-image-classifier (s/keys :req-un [::image-classifier])) (s/def ::wrapped-detector (s/keys :req-un [::object-detector])) +(defn- format-detection-predictions [predictions] + (mapv (fn [[c p]] + (let [[prob xmin ymin xmax ymax] (mapv float p)] + {:class c :prob prob :x-min xmin :y-min ymin :x-max xmax :y-max ymax})) + predictions)) + +(defn- format-classification-predictions [predictions] + (mapv (fn [[c p]] {:class c :prob p}) predictions)) + (defprotocol APredictor (predict [wrapped-predictor inputs]) (predict-with-ndarray [wrapped-predictor input-arrays])) @@ -87,19 +99,20 @@ [wrapped-predictor inputs] (util/validate! ::wrapped-predictor wrapped-predictor "Invalid predictor") - (util/validate! ::vec-of-float-arrays inputs + (util/validate! ::vvec-of-numbers inputs "Invalid inputs") - (util/coerce-return-recursive - (.predict (:predictor wrapped-predictor) - (util/vec->indexed-seq inputs)))) + (->> (.predict (:predictor wrapped-predictor) + (util/vec->indexed-seq (mapv float-array inputs))) + (util/coerce-return-recursive) + (mapv #(mapv float %)))) (predict-with-ndarray [wrapped-predictor input-arrays] (util/validate! ::wrapped-predictor wrapped-predictor "Invalid predictor") (util/validate! ::vec-of-ndarrays input-arrays "Invalid input arrays") - (util/coerce-return-recursive - (.predictWithNDArray (:predictor wrapped-predictor) - (util/vec->indexed-seq input-arrays))))) + (-> (.predictWithNDArray (:predictor wrapped-predictor) + (util/vec->indexed-seq input-arrays)) + (util/coerce-return-recursive)))) (s/def ::nil-or-int (s/nilable int?)) @@ -111,13 +124,14 @@ ([wrapped-classifier inputs topk] (util/validate! ::wrapped-classifier wrapped-classifier "Invalid classifier") - (util/validate! ::vec-of-float-arrays inputs + (util/validate! ::vvec-of-numbers inputs "Invalid inputs") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.classify (:classifier wrapped-classifier) - (util/vec->indexed-seq inputs) - (util/->int-option topk))))) + (->> (.classify (:classifier wrapped-classifier) + (util/vec->indexed-seq (mapv float-array inputs)) + (util/->int-option topk)) + (util/coerce-return-recursive) + (format-classification-predictions)))) (classify-with-ndarray ([wrapped-classifier inputs] (classify-with-ndarray wrapped-classifier inputs nil)) @@ -127,10 +141,11 @@ (util/validate! ::vec-of-ndarrays inputs "Invalid inputs") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.classifyWithNDArray (:classifier wrapped-classifier) - (util/vec->indexed-seq inputs) - (util/->int-option topk))))) + (->> (.classifyWithNDArray (:classifier wrapped-classifier) + (util/vec->indexed-seq inputs) + (util/->int-option topk)) + (util/coerce-return-recursive) + (mapv format-classification-predictions)))) WrappedImageClassifier (classify ([wrapped-image-classifier inputs] @@ -138,13 +153,14 @@ ([wrapped-image-classifier inputs topk] (util/validate! ::wrapped-image-classifier wrapped-image-classifier "Invalid classifier") - (util/validate! ::vec-of-float-arrays inputs + (util/validate! ::vvec-of-numbers inputs "Invalid inputs") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.classify (:image-classifier wrapped-image-classifier) - (util/vec->indexed-seq inputs) - (util/->int-option topk))))) + (->> (.classify (:image-classifier wrapped-image-classifier) + (util/vec->indexed-seq (mapv float-array inputs)) + (util/->int-option topk)) + (util/coerce-return-recursive) + (format-classification-predictions)))) (classify-with-ndarray ([wrapped-image-classifier inputs] (classify-with-ndarray wrapped-image-classifier inputs nil)) @@ -154,10 +170,11 @@ (util/validate! ::vec-of-ndarrays inputs "Invalid inputs") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.classifyWithNDArray (:image-classifier wrapped-image-classifier) - (util/vec->indexed-seq inputs) - (util/->int-option topk)))))) + (->> (.classifyWithNDArray (:image-classifier wrapped-image-classifier) + (util/vec->indexed-seq inputs) + (util/->int-option topk)) + (util/coerce-return-recursive) + (mapv format-classification-predictions))))) (s/def ::image #(instance? BufferedImage %)) (s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64}) @@ -175,11 +192,12 @@ (util/validate! ::image image "Invalid image") (util/validate! ::nil-or-int topk "Invalid top-K") (util/validate! ::dtype dtype "Invalid dtype") - (util/coerce-return-recursive - (.classifyImage (:image-classifier wrapped-image-classifier) - image - (util/->int-option topk) - dtype)))) + (->> (.classifyImage (:image-classifier wrapped-image-classifier) + image + (util/->int-option topk) + dtype) + (util/coerce-return-recursive) + (mapv format-classification-predictions)))) (classify-image-batch ([wrapped-image-classifier images] (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32)) @@ -188,13 +206,15 @@ ([wrapped-image-classifier images topk dtype] (util/validate! ::wrapped-image-classifier wrapped-image-classifier "Invalid classifier") + (util/validate! ::batch-images images "Invalid Batch Images") (util/validate! ::nil-or-int topk "Invalid top-K") (util/validate! ::dtype dtype "Invalid dtype") - (util/coerce-return-recursive - (.classifyImageBatch (:image-classifier wrapped-image-classifier) - images - (util/->int-option topk) - dtype))))) + (->> (.classifyImageBatch (:image-classifier wrapped-image-classifier) + (util/vec->indexed-seq images) + (util/->int-option topk) + dtype) + (util/coerce-return-recursive) + (mapv format-classification-predictions))))) (extend-protocol AObjectDetector WrappedObjectDetector @@ -206,10 +226,11 @@ "Invalid object detector") (util/validate! ::image image "Invalid image") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.imageObjectDetect (:object-detector wrapped-detector) - image - (util/->int-option topk))))) + (->> (.imageObjectDetect (:object-detector wrapped-detector) + image + (util/->int-option topk)) + (util/coerce-return-recursive) + (mapv format-detection-predictions)))) (detect-objects-batch ([wrapped-detector images] (detect-objects-batch wrapped-detector images nil)) @@ -217,10 +238,12 @@ (util/validate! ::wrapped-detector wrapped-detector "Invalid object detector") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.imageBatchObjectDetect (:object-detector wrapped-detector) - images - (util/->int-option topk))))) + (util/validate! ::batch-images images "Invalid Batch Images") + (->> (.imageBatchObjectDetect (:object-detector wrapped-detector) + (util/vec->indexed-seq images) + (util/->int-option topk)) + (util/coerce-return-recursive) + (mapv format-detection-predictions)))) (detect-objects-with-ndarrays ([wrapped-detector input-arrays] (detect-objects-with-ndarrays wrapped-detector input-arrays nil)) @@ -230,10 +253,11 @@ (util/validate! ::vec-of-ndarrays input-arrays "Invalid inputs") (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.objectDetectWithNDArray (:object-detector wrapped-detector) - (util/vec->indexed-seq input-arrays) - (util/->int-option topk)))))) + (->> (.objectDetectWithNDArray (:object-detector wrapped-detector) + (util/vec->indexed-seq input-arrays) + (util/->int-option topk)) + (util/coerce-return-recursive) + (mapv format-detection-predictions))))) (defprotocol AInferenceFactory (create-predictor [factory] [factory opts]) @@ -324,10 +348,12 @@ (defn buffered-image-to-pixels "Convert input BufferedImage to NDArray of input shape" - [image input-shape-vec] - (util/validate! ::image image "Invalid image") - (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector") - (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32)) + ([image input-shape-vec] + (buffered-image-to-pixels image input-shape-vec dtype/FLOAT32)) + ([image input-shape-vec dtype] + (util/validate! ::image image "Invalid image") + (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector") + (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype))) (s/def ::image-path string?) (s/def ::image-paths (s/coll-of ::image-path)) @@ -342,4 +368,5 @@ "Loads images from a list of file names" [image-paths] (util/validate! ::image-paths image-paths "Invalid image paths") - (ImageClassifier/loadInputBatch (util/convert-vector image-paths))) + (util/scala-vector->vec + (ImageClassifier/loadInputBatch (util/convert-vector image-paths)))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj index b459b06132b2..e3935c31e342 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj @@ -19,6 +19,7 @@ [org.apache.clojure-mxnet.dtype :as dtype] [org.apache.clojure-mxnet.infer :as infer] [org.apache.clojure-mxnet.layout :as layout] + [org.apache.clojure-mxnet.ndarray :as ndarray] [clojure.java.io :as io] [clojure.java.shell :refer [sh]] [clojure.test :refer :all])) @@ -45,32 +46,83 @@ [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)] (is (= 1000 (count predictions-all))) (is (= 10 (count predictions-with-default-dtype))) - (is (some? predictions)) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(float? (second %)) predictions)) - (is (every? #(< 0 (second %) 1) predictions)) - (is (= ["n02123159 tiger cat" - "n02124075 Egyptian cat" - "n02123045 tabby, tabby cat" - "n02127052 lynx, catamount" - "n02128757 snow leopard, ounce, Panthera uncia"] - (map first predictions))))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1))))) (deftest test-batch-classification (let [classifier (create-classifier) image-batch (infer/load-image-paths ["test/test-images/kitten.jpg" "test/test-images/Pug-Cookie.jpg"]) - batch-predictions-all (infer/classify-image-batch classifier image-batch) - batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10) - batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32) - predictions (first batch-predictions)] - (is (= 1000 (count (first batch-predictions-all)))) - (is (= 10 (count (first batch-predictions-with-default-dtype)))) - (is (some? batch-predictions)) + [batch-predictions-all] (infer/classify-image-batch classifier image-batch) + [batch-predictions-with-default-dtype] (infer/classify-image-batch classifier image-batch 10) + [predictions] (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)] + (is (= 1000 (count batch-predictions-all))) + (is (= 10 (count batch-predictions-with-default-dtype))) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(float? (second %)) predictions)) - (is (every? #(< 0 (second %) 1) predictions)))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1))))) + +(deftest test-single-classification-with-ndarray + (let [classifier (create-classifier) + image (-> (infer/load-image-from-file "test/test-images/kitten.jpg") + (infer/reshape-image 224 224) + (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32) + (ndarray/expand-dims 0)) + [predictions-all] (infer/classify-with-ndarray classifier [image]) + [predictions] (infer/classify-with-ndarray classifier [image] 5)] + (is (= 1000 (count predictions-all))) + (is (= 5 (count predictions))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1))))) + +(deftest test-single-classify + (let [classifier (create-classifier) + image (-> (infer/load-image-from-file "test/test-images/kitten.jpg") + (infer/reshape-image 224 224) + (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32) + (ndarray/expand-dims 0)) + predictions-all (infer/classify classifier [(ndarray/->vec image)]) + predictions (infer/classify classifier [(ndarray/->vec image)] 5)] + (is (= 1000 (count predictions-all))) + (is (= 5 (count predictions))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1))))) + +(deftest test-base-classification-with-ndarray + (let [descriptors [{:name "data" + :shape [1 3 224 224] + :layout layout/NCHW + :dtype dtype/FLOAT32}] + factory (infer/model-factory model-path-prefix descriptors) + classifier (infer/create-classifier factory) + image (-> (infer/load-image-from-file "test/test-images/kitten.jpg") + (infer/reshape-image 224 224) + (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32) + (ndarray/expand-dims 0)) + [predictions-all] (infer/classify-with-ndarray classifier [image]) + [predictions] (infer/classify-with-ndarray classifier [image] 5)] + (is (= 1000 (count predictions-all))) + (is (= 5 (count predictions))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1))))) + +(deftest test-base-single-classify + (let [descriptors [{:name "data" + :shape [1 3 224 224] + :layout layout/NCHW + :dtype dtype/FLOAT32}] + factory (infer/model-factory model-path-prefix descriptors) + classifier (infer/create-classifier factory) + image (-> (infer/load-image-from-file "test/test-images/kitten.jpg") + (infer/reshape-image 224 224) + (infer/buffered-image-to-pixels [3 224 224] dtype/FLOAT32) + (ndarray/expand-dims 0)) + predictions-all (infer/classify classifier [(ndarray/->vec image)]) + predictions (infer/classify classifier [(ndarray/->vec image)] 5)] + (is (= 1000 (count predictions-all))) + (is (= 5 (count predictions))) + (is (= "n02123159 tiger cat" (:class (first predictions)))) + (is (= (< 0 (:prob (first predictions)) 1))))) + + diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj index 3a0e3d30a1d9..e2b9579c7000 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj @@ -21,7 +21,8 @@ [org.apache.clojure-mxnet.layout :as layout] [clojure.java.io :as io] [clojure.java.shell :refer [sh]] - [clojure.test :refer :all])) + [clojure.test :refer :all] + [org.apache.clojure-mxnet.ndarray :as ndarray])) (def model-dir "data/") (def model-path-prefix (str model-dir "resnet50_ssd/resnet50_ssd_model")) @@ -41,27 +42,41 @@ (let [detector (create-detector) image (infer/load-image-from-file "test/test-images/kitten.jpg") [predictions-all] (infer/detect-objects detector image) - [predictions] (infer/detect-objects detector image 5)] + [predictions] (infer/detect-objects detector image 5) + {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)] (is (some? predictions)) (is (= 5 (count predictions))) (is (= 13 (count predictions-all))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(= 5 (count (second %))) predictions)) - (is (every? #(< 0 (first (second %)) 1) predictions)) - (is (= "cat" (first (first predictions)))))) + (is (= "cat" class)) + (is (< 0.8 prob)) + (every? #(< 0 % 1) [x-min x-max y-min y-max]))) (deftest test-batch-detection (let [detector (create-detector) image-batch (infer/load-image-paths ["test/test-images/kitten.jpg" "test/test-images/Pug-Cookie.jpg"]) - batch-predictions-all (infer/detect-objects-batch detector image-batch) - batch-predictions (infer/detect-objects-batch detector image-batch 5) - predictions (first batch-predictions)] - (is (some? batch-predictions)) - (is (= 13 (count (first batch-predictions-all)))) + [batch-predictions-all] (infer/detect-objects-batch detector image-batch) + [predictions] (infer/detect-objects-batch detector image-batch 5) + {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)] + (is (some? predictions)) + (is (= 13 (count batch-predictions-all))) (is (= 5 (count predictions))) - (is (every? #(= 2 (count %)) predictions)) - (is (every? #(string? (first %)) predictions)) - (is (every? #(= 5 (count (second %))) predictions)) - (is (every? #(< 0 (first (second %)) 1) predictions)))) + (is (= "cat" class)) + (is (< 0.8 prob)) + (every? #(< 0 % 1) [x-min x-max y-min y-max]))) + +(deftest test-detection-with-ndarrays + (let [detector (create-detector) + image (-> (infer/load-image-from-file "test/test-images/kitten.jpg") + (infer/reshape-image 512 512) + (infer/buffered-image-to-pixels [3 512 512] dtype/FLOAT32) + (ndarray/expand-dims 0)) + [predictions-all] (infer/detect-objects-with-ndarrays detector [image]) + [predictions] (infer/detect-objects-with-ndarrays detector [image] 1) + {:keys [class prob x-min x-max y-min y-max] :as pred} (first predictions)] + (is (some? predictions-all)) + (is (= 1 (count predictions))) + (is (= "cat" class)) + (is (< 0.8 prob)) + (every? #(< 0 % 1) [x-min x-max y-min y-max]))) + diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj index 0e7532bc2258..e1526be61fbf 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/predictor_test.clj @@ -24,7 +24,8 @@ [clojure.java.io :as io] [clojure.java.shell :refer [sh]] [clojure.string :refer [split]] - [clojure.test :refer :all])) + [clojure.test :refer :all] + [org.apache.clojure-mxnet.util :as util])) (def model-dir "data/") (def model-path-prefix (str model-dir "resnet-18/resnet-18")) @@ -42,6 +43,22 @@ factory (infer/model-factory model-path-prefix descriptors)] (infer/create-predictor factory))) +(deftest predictor-test-with-ndarray + (let [predictor (create-predictor) + image-ndarray (-> "test/test-images/kitten.jpg" + infer/load-image-from-file + (infer/reshape-image width height) + (infer/buffered-image-to-pixels [3 width height]) + (ndarray/expand-dims 0)) + predictions (infer/predict-with-ndarray predictor [image-ndarray]) + synset-file (-> (io/file model-path-prefix) + (.getParent) + (io/file "synset.txt")) + synset-names (split (slurp synset-file) #"\n") + [best-index] (ndarray/->int-vec (ndarray/argmax (first predictions) 1)) + best-prediction (synset-names best-index)] + (is (= "n02123159 tiger cat" best-prediction)))) + (deftest predictor-test (let [predictor (create-predictor) image-ndarray (-> "test/test-images/kitten.jpg" @@ -49,11 +66,12 @@ (infer/reshape-image width height) (infer/buffered-image-to-pixels [3 width height]) (ndarray/expand-dims 0)) - [predictions] (infer/predict-with-ndarray predictor [image-ndarray]) + predictions (infer/predict predictor [(ndarray/->vec image-ndarray)]) synset-file (-> (io/file model-path-prefix) (.getParent) (io/file "synset.txt")) synset-names (split (slurp synset-file) #"\n") - [best-index] (ndarray/->int-vec (ndarray/argmax predictions 1)) + ndarray-preds (ndarray/array (first predictions) [1 1000]) + [best-index] (ndarray/->int-vec (ndarray/argmax ndarray-preds 1)) best-prediction (synset-names best-index)] (is (= "n02123159 tiger cat" best-prediction))))