diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b9f84d592a70..5b5fdce712f1 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -193,6 +193,7 @@ List of Contributors * [Yuxi Hu](https://github.com/yuxihu) * [Harsh Patel](https://github.com/harshp8l) * [Xiao Wang](https://github.com/BeyonderXX) +* [Piyush Ghai](https://github.com/piyushghai) Label Bot --------- diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj index b2b23da6274e..224a39275dac 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj @@ -18,6 +18,7 @@ (ns org.apache.clojure-mxnet.infer (:refer-clojure :exclude [type]) (:require [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.dtype :as dtype] [org.apache.clojure-mxnet.io :as mx-io] [org.apache.clojure-mxnet.shape :as shape] [org.apache.clojure-mxnet.util :as util] @@ -62,10 +63,12 @@ (defprotocol AImageClassifier (classify-image [wrapped-image-classifier image] - [wrapped-image-classifier image topk]) + [wrapped-image-classifier image topk] + [wrapped-image-classifier image topk dtype]) (classify-image-batch [wrapped-image-classifier images] - [wrapped-image-classifier images topk])) + [wrapped-image-classifier images topk] + [wrapped-image-classifier images topk dtype])) (defprotocol AObjectDetector (detect-objects @@ -80,7 +83,8 @@ (extend-protocol APredictor WrappedPredictor - (predict [wrapped-predictor inputs] + (predict + [wrapped-predictor inputs] (util/validate! ::wrapped-predictor wrapped-predictor "Invalid predictor") (util/validate! ::vec-of-float-arrays inputs @@ -101,62 +105,50 @@ (extend-protocol AClassifier WrappedClassifier - (classify [wrapped-classifier inputs] - (util/validate! ::wrapped-classifier wrapped-classifier - "Invalid classifier") - (util/validate! ::vec-of-float-arrays inputs - "Invalid inputs") - (classify wrapped-classifier inputs nil)) - (classify [wrapped-classifier inputs topk] - (util/validate! ::wrapped-classifier wrapped-classifier - "Invalid classifier") - (util/validate! ::vec-of-float-arrays inputs - "Invalid inputs") - (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.classify (:classifier wrapped-classifier) - (util/vec->indexed-seq inputs) - (util/->int-option topk)))) - (classify-with-ndarray [wrapped-classifier inputs] - (util/validate! ::wrapped-classifier wrapped-classifier - "Invalid classifier") - (util/validate! ::vec-of-ndarrays inputs - "Invalid inputs") - (classify-with-ndarray wrapped-classifier inputs nil)) - (classify-with-ndarray [wrapped-classifier inputs topk] - (util/validate! ::wrapped-classifier wrapped-classifier - "Invalid classifier") - (util/validate! ::vec-of-ndarrays inputs - "Invalid inputs") - (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.classifyWithNDArray (:classifier wrapped-classifier) - (util/vec->indexed-seq inputs) - (util/->int-option topk)))) + (classify + ([wrapped-classifier inputs] + (classify wrapped-classifier inputs nil)) + ([wrapped-classifier inputs topk] + (util/validate! ::wrapped-classifier wrapped-classifier + "Invalid classifier") + (util/validate! ::vec-of-float-arrays inputs + "Invalid inputs") + (util/validate! ::nil-or-int topk "Invalid top-K") + (util/coerce-return-recursive + (.classify (:classifier wrapped-classifier) + (util/vec->indexed-seq inputs) + (util/->int-option topk))))) + (classify-with-ndarray + ([wrapped-classifier inputs] + (classify-with-ndarray wrapped-classifier inputs nil)) + ([wrapped-classifier inputs topk] + (util/validate! ::wrapped-classifier wrapped-classifier + "Invalid classifier") + (util/validate! ::vec-of-ndarrays inputs + "Invalid inputs") + (util/validate! ::nil-or-int topk "Invalid top-K") + (util/coerce-return-recursive + (.classifyWithNDArray (:classifier wrapped-classifier) + (util/vec->indexed-seq inputs) + (util/->int-option topk))))) WrappedImageClassifier - (classify [wrapped-image-classifier inputs] - (util/validate! ::wrapped-image-classifier wrapped-image-classifier - "Invalid classifier") - (util/validate! ::vec-of-float-arrays inputs - "Invalid inputs") - (classify wrapped-image-classifier inputs nil)) - (classify [wrapped-image-classifier inputs topk] - (util/validate! ::wrapped-image-classifier wrapped-image-classifier - "Invalid classifier") - (util/validate! ::vec-of-float-arrays inputs - "Invalid inputs") - (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.classify (:image-classifier wrapped-image-classifier) - (util/vec->indexed-seq inputs) - (util/->int-option topk)))) - (classify-with-ndarray [wrapped-image-classifier inputs] - (util/validate! ::wrapped-image-classifier wrapped-image-classifier - "Invalid classifier") - (util/validate! ::vec-of-ndarrays inputs - "Invalid inputs") - (classify-with-ndarray wrapped-image-classifier inputs nil)) - (classify-with-ndarray [wrapped-image-classifier inputs topk] + (classify + ([wrapped-image-classifier inputs] + (classify wrapped-image-classifier inputs nil)) + ([wrapped-image-classifier inputs topk] + (util/validate! ::wrapped-image-classifier wrapped-image-classifier + "Invalid classifier") + (util/validate! ::vec-of-float-arrays inputs + "Invalid inputs") + (util/validate! ::nil-or-int topk "Invalid top-K") + (util/coerce-return-recursive + (.classify (:image-classifier wrapped-image-classifier) + (util/vec->indexed-seq inputs) + (util/->int-option topk))))) + (classify-with-ndarray + ([wrapped-image-classifier inputs] + (classify-with-ndarray wrapped-image-classifier inputs nil)) + ([wrapped-image-classifier inputs topk] (util/validate! ::wrapped-image-classifier wrapped-image-classifier "Invalid classifier") (util/validate! ::vec-of-ndarrays inputs @@ -165,83 +157,83 @@ (util/coerce-return-recursive (.classifyWithNDArray (:image-classifier wrapped-image-classifier) (util/vec->indexed-seq inputs) - (util/->int-option topk))))) + (util/->int-option topk)))))) (s/def ::image #(instance? BufferedImage %)) +(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64}) (extend-protocol AImageClassifier WrappedImageClassifier - (classify-image [wrapped-image-classifier image] - (util/validate! ::wrapped-image-classifier wrapped-image-classifier - "Invalid classifier") - (util/validate! ::image image "Invalid image") - (classify-image wrapped-image-classifier image nil)) - (classify-image [wrapped-image-classifier image topk] - (util/validate! ::wrapped-image-classifier wrapped-image-classifier - "Invalid classifier") - (util/validate! ::image image "Invalid image") - (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.classifyImage (:image-classifier wrapped-image-classifier) - image - (util/->int-option topk)))) - (classify-image-batch [wrapped-image-classifier images] - (util/validate! ::wrapped-image-classifier wrapped-image-classifier - "Invalid classifier") - (classify-image-batch wrapped-image-classifier images nil)) - (classify-image-batch [wrapped-image-classifier images topk] - (util/validate! ::wrapped-image-classifier wrapped-image-classifier - "Invalid classifier") - (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.classifyImageBatch (:image-classifier wrapped-image-classifier) - images - (util/->int-option topk))))) + (classify-image + ([wrapped-image-classifier image] + (classify-image wrapped-image-classifier image nil dtype/FLOAT32)) + ([wrapped-image-classifier image topk] + (classify-image wrapped-image-classifier image topk dtype/FLOAT32)) + ([wrapped-image-classifier image topk dtype] + (util/validate! ::wrapped-image-classifier wrapped-image-classifier + "Invalid classifier") + (util/validate! ::image image "Invalid image") + (util/validate! ::nil-or-int topk "Invalid top-K") + (util/validate! ::dtype dtype "Invalid dtype") + (util/coerce-return-recursive + (.classifyImage (:image-classifier wrapped-image-classifier) + image + (util/->int-option topk) + dtype)))) + (classify-image-batch + ([wrapped-image-classifier images] + (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32)) + ([wrapped-image-classifier images topk] + (classify-image-batch wrapped-image-classifier images topk dtype/FLOAT32)) + ([wrapped-image-classifier images topk dtype] + (util/validate! ::wrapped-image-classifier wrapped-image-classifier + "Invalid classifier") + (util/validate! ::nil-or-int topk "Invalid top-K") + (util/validate! ::dtype dtype "Invalid dtype") + (util/coerce-return-recursive + (.classifyImageBatch (:image-classifier wrapped-image-classifier) + images + (util/->int-option topk) + dtype))))) (extend-protocol AObjectDetector WrappedObjectDetector - (detect-objects [wrapped-detector image] - (util/validate! ::wrapped-detector wrapped-detector - "Invalid object detector") - (util/validate! ::image image "Invalid image") - (detect-objects wrapped-detector image nil)) - (detect-objects [wrapped-detector image topk] - (util/validate! ::wrapped-detector wrapped-detector - "Invalid object detector") - (util/validate! ::image image "Invalid image") - (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.imageObjectDetect (:object-detector wrapped-detector) - image - (util/->int-option topk)))) - (detect-objects-batch [wrapped-detector images] - (util/validate! ::wrapped-detector wrapped-detector - "Invalid object detector") - (detect-objects-batch wrapped-detector images nil)) - (detect-objects-batch [wrapped-detector images topk] - (util/validate! ::wrapped-detector wrapped-detector - "Invalid object detector") - (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.imageBatchObjectDetect (:object-detector wrapped-detector) - images - (util/->int-option topk)))) - (detect-objects-with-ndarrays [wrapped-detector input-arrays] - (util/validate! ::wrapped-detector wrapped-detector - "Invalid object detector") - (util/validate! ::vec-of-ndarrays input-arrays - "Invalid inputs") - (detect-objects-with-ndarrays wrapped-detector input-arrays nil)) - (detect-objects-with-ndarrays [wrapped-detector input-arrays topk] + (detect-objects + ([wrapped-detector image] + (detect-objects wrapped-detector image nil)) + ([wrapped-detector image topk] (util/validate! ::wrapped-detector wrapped-detector "Invalid object detector") - (util/validate! ::vec-of-ndarrays input-arrays - "Invalid inputs") - (util/validate! ::nil-or-int topk "Invalid top-K") - (util/coerce-return-recursive - (.objectDetectWithNDArray (:object-detector wrapped-detector) - (util/vec->indexed-seq input-arrays) + (util/validate! ::image image "Invalid image") + (util/validate! ::nil-or-int topk "Invalid top-K") + (util/coerce-return-recursive + (.imageObjectDetect (:object-detector wrapped-detector) + image + (util/->int-option topk))))) + (detect-objects-batch + ([wrapped-detector images] + (detect-objects-batch wrapped-detector images nil)) + ([wrapped-detector images topk] + (util/validate! ::wrapped-detector wrapped-detector + "Invalid object detector") + (util/validate! ::nil-or-int topk "Invalid top-K") + (util/coerce-return-recursive + (.imageBatchObjectDetect (:object-detector wrapped-detector) + images (util/->int-option topk))))) + (detect-objects-with-ndarrays + ([wrapped-detector input-arrays] + (detect-objects-with-ndarrays wrapped-detector input-arrays nil)) + ([wrapped-detector input-arrays topk] + (util/validate! ::wrapped-detector wrapped-detector + "Invalid object detector") + (util/validate! ::vec-of-ndarrays input-arrays + "Invalid inputs") + (util/validate! ::nil-or-int topk "Invalid top-K") + (util/coerce-return-recursive + (.objectDetectWithNDArray (:object-detector wrapped-detector) + (util/vec->indexed-seq input-arrays) + (util/->int-option topk)))))) (defprotocol AInferenceFactory (create-predictor [factory] [factory opts]) @@ -335,7 +327,7 @@ [image input-shape-vec] (util/validate! ::image image "Invalid image") (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector") - (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec))) + (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32)) (s/def ::image-path string?) (s/def ::image-paths (s/coll-of ::image-path)) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj new file mode 100644 index 000000000000..0967df2289d8 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj @@ -0,0 +1,46 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.primitives + (:import (org.apache.mxnet MX_PRIMITIVES$MX_FLOAT MX_PRIMITIVES$MX_Double + MX_PRIMITIVES$MX_PRIMITIVE_TYPE))) + + +;;; Defines customer mx primitives that can be used for mathematical computations +;;; in NDArrays to control precision. Currently Float and Double are supported + +;;; For purposes of automatic conversion in ndarray functions, doubles are default +;; to specify using floats you must use a Float + +(defn mx-float + "Creates a MXNet float primitive" + [num] + (new MX_PRIMITIVES$MX_FLOAT num)) + +(defn mx-double + "Creates a MXNet double primitive" + [num] + (new MX_PRIMITIVES$MX_Double num)) + +(defn ->num + "Returns the underlying number value" + [primitive] + (.data primitive)) + +(defn primitive? [x] + (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE x)) + diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj index 21e31baa3a9b..43970c0abd79 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj @@ -19,6 +19,7 @@ (:require [clojure.spec.alpha :as s] [t6.from-scala.core :refer [$ $$] :as $] [clojure.string :as string] + [org.apache.clojure-mxnet.primitives :as primitives] [org.apache.clojure-mxnet.shape :as mx-shape]) (:import (org.apache.mxnet NDArray) (scala Product Tuple2 Tuple3) @@ -36,7 +37,8 @@ "byte<>" "byte-array" "java.lang.String<>" "vec-or-strings" "org.apache.mxnet.NDArray" "ndarray" - "org.apache.mxnet.Symbol" "sym"}) + "org.apache.mxnet.Symbol" "sym" + "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"}) (def symbol-param-coerce {"java.lang.String" "sym-name" "float" "num" @@ -144,6 +146,8 @@ (and (get targets "int<>") (vector? param)) (int-array param) (and (get targets "float<>") (vector? param)) (float-array param) (and (get targets "java.lang.String<>") (vector? param)) (into-array param) + (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param) + (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param) :else param)) (defn nil-or-coerce-param [param targets] @@ -177,6 +181,7 @@ (instance? Map return-val) (scala-map->map return-val) (instance? Tuple2 return-val) (tuple->vec return-val) (instance? Tuple3 return-val) (tuple->vec return-val) + (primitives/primitive? return-val) (primitives/->num return-val) :else return-val)) (defn coerce-return-recursive [return-val] diff --git a/contrib/clojure-package/test/good-test-ndarray.clj b/contrib/clojure-package/test/good-test-ndarray.clj index 3b53b1906006..b048a819c642 100644 --- a/contrib/clojure-package/test/good-test-ndarray.clj +++ b/contrib/clojure-package/test/good-test-ndarray.clj @@ -27,11 +27,12 @@ (defn div - ([ndarray num-or-ndarray] + ([ndarray ndarray-or-double-or-float] (util/coerce-return (.$div ndarray (util/coerce-param - num-or-ndarray - #{"float" "org.apache.mxnet.NDArray"}))))) + ndarray-or-double-or-float + #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" + "org.apache.mxnet.NDArray"}))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj index 9badfed933a5..b459b06132b2 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj @@ -40,7 +40,11 @@ (deftest test-single-classification (let [classifier (create-classifier) image (infer/load-image-from-file "test/test-images/kitten.jpg") - [predictions] (infer/classify-image classifier image 5)] + [predictions-all] (infer/classify-image classifier image) + [predictions-with-default-dtype] (infer/classify-image classifier image 10) + [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)] + (is (= 1000 (count predictions-all))) + (is (= 10 (count predictions-with-default-dtype))) (is (some? predictions)) (is (= 5 (count predictions))) (is (every? #(= 2 (count %)) predictions)) @@ -58,8 +62,12 @@ (let [classifier (create-classifier) image-batch (infer/load-image-paths ["test/test-images/kitten.jpg" "test/test-images/Pug-Cookie.jpg"]) - batch-predictions (infer/classify-image-batch classifier image-batch 5) + batch-predictions-all (infer/classify-image-batch classifier image-batch) + batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10) + batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32) predictions (first batch-predictions)] + (is (= 1000 (count (first batch-predictions-all)))) + (is (= 10 (count (first batch-predictions-with-default-dtype)))) (is (some? batch-predictions)) (is (= 5 (count predictions))) (is (every? #(= 2 (count %)) predictions)) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj index 788a59491095..3a0e3d30a1d9 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj @@ -40,9 +40,11 @@ (deftest test-single-detection (let [detector (create-detector) image (infer/load-image-from-file "test/test-images/kitten.jpg") + [predictions-all] (infer/detect-objects detector image) [predictions] (infer/detect-objects detector image 5)] (is (some? predictions)) (is (= 5 (count predictions))) + (is (= 13 (count predictions-all))) (is (every? #(= 2 (count %)) predictions)) (is (every? #(string? (first %)) predictions)) (is (every? #(= 5 (count (second %))) predictions)) @@ -53,9 +55,11 @@ (let [detector (create-detector) image-batch (infer/load-image-paths ["test/test-images/kitten.jpg" "test/test-images/Pug-Cookie.jpg"]) + batch-predictions-all (infer/detect-objects-batch detector image-batch) batch-predictions (infer/detect-objects-batch detector image-batch 5) predictions (first batch-predictions)] (is (some? batch-predictions)) + (is (= 13 (count (first batch-predictions-all)))) (is (= 5 (count predictions))) (is (every? #(= 2 (count %)) predictions)) (is (every? #(string? (first %)) predictions)) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj index 79e94412d0df..9ffd3abed2f9 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj @@ -97,7 +97,7 @@ (is (= [1.0 1.0] (->vec ndhalves))))) (deftest test-full - (let [nda (full [1 2] 3)] + (let [nda (full [1 2] 3.0)] (is (= (shape nda) (mx-shape/->shape [1 2]))) (is (= [3.0 3.0] (->vec nda))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj new file mode 100644 index 000000000000..1a538e537b8b --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj @@ -0,0 +1,45 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.primitives-test + (:require [org.apache.clojure-mxnet.primitives :as primitives] + [clojure.test :refer :all]) + (:import (org.apache.mxnet MX_PRIMITIVES$MX_PRIMITIVE_TYPE + MX_PRIMITIVES$MX_FLOAT + MX_PRIMITIVES$MX_Double))) + +(deftest test-primitive-types + (is (not (primitives/primitive? 3))) + (is (primitives/primitive? (primitives/mx-float 3))) + (is (primitives/primitive? (primitives/mx-double 3)))) + +(deftest test-float-primitives + (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-float 3))) + (is (instance? MX_PRIMITIVES$MX_FLOAT (primitives/mx-float 3))) + (is (instance? Float (-> (primitives/mx-float 3) + (primitives/->num)))) + (is (= 3.0 (-> (primitives/mx-float 3) + (primitives/->num))))) + +(deftest test-double-primitives + (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-double 2))) + (is (instance? MX_PRIMITIVES$MX_Double (primitives/mx-double 2))) + (is (instance? Double (-> (primitives/mx-double 2) + (primitives/->num)))) + (is (= 2.0 (-> (primitives/mx-double 2) + (primitives/->num))))) + diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj index bd77a8a0edc6..c26f83d5aa49 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj @@ -20,6 +20,7 @@ [org.apache.clojure-mxnet.shape :as mx-shape] [org.apache.clojure-mxnet.util :as util] [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.primitives :as primitives] [org.apache.clojure-mxnet.symbol :as sym] [org.apache.clojure-mxnet.test-util :as test-util] [clojure.spec.alpha :as s]) @@ -133,6 +134,9 @@ (is (= "[F" (->> (util/coerce-param [1 2] #{"float<>"}) str (take 2) (apply str)))) (is (= "[L" (->> (util/coerce-param [1 2] #{"java.lang.String<>"}) str (take 2) (apply str)))) + (is (primitives/primitive? (util/coerce-param 1.0 #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"}))) + (is (primitives/primitive? (util/coerce-param (float 1.0) #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"}))) + (is (= 1 (util/coerce-param 1 #{"unknown"})))) (deftest test-nil-or-coerce-param @@ -171,6 +175,12 @@ (util/convert-tuple [1 2])))) (is (= [1 2 3] (util/coerce-return (util/convert-tuple [1 2 3])))) + + (is (instance? Double (util/coerce-return (primitives/mx-double 3)))) + (is (= 3.0 (util/coerce-return (primitives/mx-double 3)))) + (is (instance? Float (util/coerce-return (primitives/mx-float 2)))) + (is (= 2.0 (util/coerce-return (primitives/mx-float 2)))) + (is (= "foo" (util/coerce-return "foo")))) (deftest test-translate-keyword-shape diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala index ed7aff602f63..001bd04d2c95 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala @@ -18,7 +18,9 @@ package org.apache.mxnet import org.apache.mxnet.util.NativeLibraryLoader -import org.slf4j.{LoggerFactory, Logger} +import org.slf4j.{Logger, LoggerFactory} + +import scala.Specializable.Group private[mxnet] object Base { private val logger: Logger = LoggerFactory.getLogger("MXNetJVM") @@ -57,6 +59,9 @@ private[mxnet] object Base { val MX_REAL_TYPE = DType.Float32 + // The primitives currently supported for NDArray operations + val MX_PRIMITIVES = new Group ((Double, Float)) + try { try { tryLoadLibraryOS("mxnet-scala") diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala index 0a5683aa7ab3..20b6ed9fc806 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala @@ -93,6 +93,9 @@ private[mxnet] class LibInfo { @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle, source: Array[MXFloat], size: Int): Int + @native def mxFloat64NDArraySyncCopyFromCPU(handle: NDArrayHandle, + source: Array[Double], + size: Int): Int @native def mxNDArrayLoad(fname: String, outSize: MXUintRef, handles: ArrayBuffer[NDArrayHandle], diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala new file mode 100644 index 000000000000..cb978856963c --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +object MX_PRIMITIVES { + + /** + * This defines the basic primitives we can use in Scala for mathematical + * computations in NDArrays.This gives us a flexibility to expand to + * more supported primitives in the future. Currently Float and Double + * are supported. The functions which accept MX_PRIMITIVE_TYPE as input can also accept + * plain old Float and Double data as inputs because of the underlying + * implicit conversion between primitives to MX_PRIMITIVE_TYPE. + */ + trait MX_PRIMITIVE_TYPE extends Ordered[MX_PRIMITIVE_TYPE]{ + + def toString: String + + def unary_- : MX_PRIMITIVE_TYPE + } + + trait MXPrimitiveOrdering extends Ordering[MX_PRIMITIVE_TYPE] { + + def compare(x: MX_PRIMITIVE_TYPE, y: MX_PRIMITIVE_TYPE): Int = x.compare(y) + + } + + implicit object MX_PRIMITIVE_TYPE extends MXPrimitiveOrdering + + /** + * Wrapper over Float in Scala. + * @param data + */ + class MX_FLOAT(val data: Float) extends MX_PRIMITIVE_TYPE { + + override def toString: String = data.toString + + override def unary_- : MX_PRIMITIVE_TYPE = new MX_FLOAT(data.unary_-) + + override def compare(that: MX_PRIMITIVE_TYPE): Int = { + this.data.compareTo(that.asInstanceOf[MX_FLOAT].data) + } + } + + implicit def FloatToMX_Float(d : Float): MX_FLOAT = new MX_FLOAT(d) + + implicit def MX_FloatToFloat(d: MX_FLOAT) : Float = d.data + + implicit def IntToMX_Float(d: Int): MX_FLOAT = new MX_FLOAT(d.toFloat) + + /** + * Wrapper over Double in Scala. + * @param data + */ + class MX_Double(val data: Double) extends MX_PRIMITIVE_TYPE { + + override def toString: String = data.toString + + override def unary_- : MX_PRIMITIVE_TYPE = new MX_Double(data.unary_-) + + override def compare(that: MX_PRIMITIVE_TYPE): Int = { + this.data.compareTo(that.asInstanceOf[MX_Double].data) + } + } + + implicit def DoubleToMX_Double(d : Double): MX_Double = new MX_Double(d) + + implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data + +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 125958150b72..163ed2682532 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType +import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE} import org.slf4j.LoggerFactory import scala.collection.mutable @@ -262,16 +263,46 @@ object NDArray extends NDArrayBase { arr } - // Perform power operator + def full(shape: Shape, value: Double, ctx: Context): NDArray = { + val arr = empty(shape, ctx, DType.Float64) + arr.set(value) + arr + } + + /** + * Create a new NDArray filled with given value, with specified shape. + * @param shape shape of the NDArray. + * @param value value to be filled with + */ + def full(shape: Shape, value: Double): NDArray = { + full(shape, value, null) + } + + + /** + * Perform power operation on NDArray. Returns result as NDArray + * @param lhs + * @param rhs + */ def power(lhs: NDArray, rhs: NDArray): NDArray = { NDArray.genericNDArrayFunctionInvoke("_power", Seq(lhs, rhs)) } - def power(lhs: NDArray, rhs: Float): NDArray = { + /** + * Perform scalar power operation on NDArray. Returns result as NDArray + * @param lhs NDArray on which to perform the operation on. + * @param rhs The scalar input. Can be of type Float/Double + */ + def power(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(lhs, rhs)) } - def power(lhs: Float, rhs: NDArray): NDArray = { + /** + * Perform scalar power operation on NDArray. Returns result as NDArray + * @param lhs The scalar input. Can be of type Float/Double + * @param rhs NDArray on which to perform the operation on. + */ + def power(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = { NDArray.genericNDArrayFunctionInvoke("_rpower_scalar", Seq(lhs, rhs)) } @@ -280,11 +311,21 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("_maximum", Seq(lhs, rhs)) } - def maximum(lhs: NDArray, rhs: Float): NDArray = { + /** + * Perform the max operation on NDArray. Returns the result as NDArray. + * @param lhs NDArray on which to perform the operation on. + * @param rhs The scalar input. Can be of type Float/Double + */ + def maximum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs)) } - def maximum(lhs: Float, rhs: NDArray): NDArray = { + /** + * Perform the max operation on NDArray. Returns the result as NDArray. + * @param lhs The scalar input. Can be of type Float/Double + * @param rhs NDArray on which to perform the operation on. + */ + def maximum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = { NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs)) } @@ -293,11 +334,21 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("_minimum", Seq(lhs, rhs)) } - def minimum(lhs: NDArray, rhs: Float): NDArray = { + /** + * Perform the min operation on NDArray. Returns the result as NDArray. + * @param lhs NDArray on which to perform the operation on. + * @param rhs The scalar input. Can be of type Float/Double + */ + def minimum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs)) } - def minimum(lhs: Float, rhs: NDArray): NDArray = { + /** + * Perform the min operation on NDArray. Returns the result as NDArray. + * @param lhs The scalar input. Can be of type Float/Double + * @param rhs NDArray on which to perform the operation on. + */ + def minimum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = { NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs)) } @@ -310,7 +361,15 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_equal", Seq(lhs, rhs)) } - def equal(lhs: NDArray, rhs: Float): NDArray = { + /** + * Returns the result of element-wise **equal to** (==) comparison operation with broadcasting. + * For each element in input arrays, return 1(true) if corresponding elements are same, + * otherwise return 0(false). + * + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ + def equal(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_equal_scalar", Seq(lhs, rhs)) } @@ -324,7 +383,15 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_not_equal", Seq(lhs, rhs)) } - def notEqual(lhs: NDArray, rhs: Float): NDArray = { + /** + * Returns the result of element-wise **not equal to** (!=) comparison operation + * with broadcasting. + * For each element in input arrays, return 1(true) if corresponding elements are different, + * otherwise return 0(false). + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ + def notEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_not_equal_scalar", Seq(lhs, rhs)) } @@ -338,7 +405,16 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_greater", Seq(lhs, rhs)) } - def greater(lhs: NDArray, rhs: Float): NDArray = { + /** + * Returns the result of element-wise **greater than** (>) comparison operation + * with broadcasting. + * For each element in input arrays, return 1(true) if lhs elements are greater than rhs, + * otherwise return 0(false). + * + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ + def greater(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_greater_scalar", Seq(lhs, rhs)) } @@ -352,7 +428,16 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_greater_equal", Seq(lhs, rhs)) } - def greaterEqual(lhs: NDArray, rhs: Float): NDArray = { + /** + * Returns the result of element-wise **greater than or equal to** (>=) comparison + * operation with broadcasting. + * For each element in input arrays, return 1(true) if lhs elements are greater than equal to + * rhs, otherwise return 0(false). + * + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ + def greaterEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_greater_equal_scalar", Seq(lhs, rhs)) } @@ -366,7 +451,15 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_lesser", Seq(lhs, rhs)) } - def lesser(lhs: NDArray, rhs: Float): NDArray = { + /** + * Returns the result of element-wise **lesser than** (<) comparison operation + * with broadcasting. + * For each element in input arrays, return 1(true) if lhs elements are less than rhs, + * otherwise return 0(false). + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ + def lesser(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_lesser_scalar", Seq(lhs, rhs)) } @@ -380,7 +473,16 @@ object NDArray extends NDArrayBase { NDArray.genericNDArrayFunctionInvoke("broadcast_lesser_equal", Seq(lhs, rhs)) } - def lesserEqual(lhs: NDArray, rhs: Float): NDArray = { + /** + * Returns the result of element-wise **lesser than or equal to** (<=) comparison + * operation with broadcasting. + * For each element in input arrays, return 1(true) if lhs elements are + * lesser than equal to rhs, otherwise return 0(false). + * + * @param lhs NDArray + * @param rhs The scalar input. Can be of type Float/Double + */ + def lesserEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_lesser_equal_scalar", Seq(lhs, rhs)) } @@ -397,6 +499,16 @@ object NDArray extends NDArrayBase { arr } + def array(sourceArr: Array[Double], shape: Shape, ctx: Context): NDArray = { + val arr = empty(shape, ctx, dtype = DType.Float64) + arr.set(sourceArr) + arr + } + + def array(sourceArr: Array[Double], shape: Shape): NDArray = { + array(sourceArr, shape, null) + } + /** * Returns evenly spaced values within a given interval. * Values are generated within the half-open interval [`start`, `stop`). In other @@ -645,6 +757,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length)) } + private def syncCopyfrom(source: Array[Double]): Unit = { + require(source.length == size, + s"array size (${source.length}) do not match the size of NDArray ($size)") + checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length)) + } + /** * Return a sliced NDArray that shares memory with current one. * NDArray only support continuous slicing on axis 0 @@ -759,7 +877,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * @param value Value to set * @return Current NDArray */ - def set(value: Float): NDArray = { + def set(value: MX_PRIMITIVE_TYPE): NDArray = { require(writable, "trying to assign to a readonly NDArray") NDArray.genericNDArrayFunctionInvoke("_set_value", Seq(value), Map("out" -> this)) this @@ -776,11 +894,17 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this } + def set(other: Array[Double]): NDArray = { + require(writable, "trying to assign to a readonly NDArray") + syncCopyfrom(other) + this + } + def +(other: NDArray): NDArray = { NDArray.genericNDArrayFunctionInvoke("_plus", Seq(this, other)) } - def +(other: Float): NDArray = { + def +(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_plus_scalar", Seq(this, other)) } @@ -792,7 +916,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this } - def +=(other: Float): NDArray = { + def +=(other: MX_PRIMITIVE_TYPE): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to add to a readonly NDArray") } @@ -804,7 +928,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.genericNDArrayFunctionInvoke("_minus", Seq(this, other)) } - def -(other: Float): NDArray = { + def -(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_minus_scalar", Seq(this, other)) } @@ -816,7 +940,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this } - def -=(other: Float): NDArray = { + def -=(other: MX_PRIMITIVE_TYPE): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to subtract from a readonly NDArray") } @@ -828,7 +952,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.genericNDArrayFunctionInvoke("_mul", Seq(this, other)) } - def *(other: Float): NDArray = { + def *(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_mul_scalar", Seq(this, other)) } @@ -844,7 +968,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this } - def *=(other: Float): NDArray = { + def *=(other: MX_PRIMITIVE_TYPE): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to multiply to a readonly NDArray") } @@ -856,7 +980,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.genericNDArrayFunctionInvoke("_div", Seq(this, other)) } - def /(other: Float): NDArray = { + def /(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_div_scalar", Seq(this, other)) } @@ -868,7 +992,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this } - def /=(other: Float): NDArray = { + def /=(other: MX_PRIMITIVE_TYPE): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to divide from a readonly NDArray") } @@ -880,7 +1004,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.power(this, other) } - def **(other: Float): NDArray = { + def **(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.power(this, other) } @@ -888,7 +1012,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.genericNDArrayFunctionInvoke("_power", Seq(this, other), Map("out" -> this)) } - def **=(other: Float): NDArray = { + def **=(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(this, other), Map("out" -> this)) } @@ -896,7 +1020,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.greater(this, other) } - def >(other: Float): NDArray = { + def >(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.greater(this, other) } @@ -904,7 +1028,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.greaterEqual(this, other) } - def >=(other: Float): NDArray = { + def >=(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.greaterEqual(this, other) } @@ -912,7 +1036,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.lesser(this, other) } - def <(other: Float): NDArray = { + def <(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.lesser(this, other) } @@ -920,7 +1044,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.lesserEqual(this, other) } - def <=(other: Float): NDArray = { + def <=(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.lesserEqual(this, other) } @@ -928,7 +1052,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.genericNDArrayFunctionInvoke("_mod", Seq(this, other)) } - def %(other: Float): NDArray = { + def %(other: MX_PRIMITIVE_TYPE): NDArray = { NDArray.genericNDArrayFunctionInvoke("_mod_scalar", Seq(this, other)) } @@ -940,7 +1064,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this } - def %=(other: Float): NDArray = { + def %=(other: MX_PRIMITIVE_TYPE): NDArray = { if (!writable) { throw new IllegalArgumentException("trying to take modulo from a readonly NDArray") } @@ -956,6 +1080,14 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, internal.toFloatArray } + /** + * Return a copied flat java array of current array (row-major) with datatype as Float64/Double. + * @return A copy of array content. + */ + def toFloat64Array: Array[Double] = { + internal.toDoubleArray + } + def internal: NDArrayInternal = { val myType = dtype val arrLength = DType.numOfBytes(myType) * size @@ -975,6 +1107,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, this.toArray(0) } + def toFloat64Scalar: Double = { + require(shape == Shape(1), "The current array is not a scalar") + this.toFloat64Array(0) + } + /** * Copy the content of current array to other. * @@ -997,7 +1134,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * @return The copy target NDArray */ def copyTo(ctx: Context): NDArray = { - val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true)) + val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true, dtype = dtype)) copyTo(ret) } @@ -1047,11 +1184,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, private[mxnet] object NDArrayConversions { implicit def int2Scalar(x: Int): NDArrayConversions = new NDArrayConversions(x.toFloat) - implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x.toFloat) + implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x) implicit def float2Scalar(x: Float): NDArrayConversions = new NDArrayConversions(x) } -private[mxnet] class NDArrayConversions(val value: Float) { +private[mxnet] class NDArrayConversions(val value: MX_PRIMITIVE_TYPE) { def +(other: NDArray): NDArray = { other + value } @@ -1145,34 +1282,39 @@ private[mxnet] class NDArrayFuncReturn(private[mxnet] val arr: Array[NDArray]) { def waitToRead(): Unit = head.waitToRead() def context: Context = head.context def set(value: Float): NDArray = head.set(value) + def set(value: Double): NDArray = head.set(value) def set(other: NDArray): NDArray = head.set(other) def set(other: Array[Float]): NDArray = head.set(other) + def set(other: Array[Double]): NDArray = head.set(other) def +(other: NDArray): NDArray = head + other - def +(other: Float): NDArray = head + other + def +(other: MX_PRIMITIVE_TYPE): NDArray = head + other def +=(other: NDArray): NDArray = head += other - def +=(other: Float): NDArray = head += other + def +=(other: MX_PRIMITIVE_TYPE): NDArray = head += other def -(other: NDArray): NDArray = head - other - def -(other: Float): NDArray = head - other + def -(other: MX_PRIMITIVE_TYPE): NDArray = head - other def -=(other: NDArray): NDArray = head -= other - def -=(other: Float): NDArray = head -= other + def -=(other: MX_PRIMITIVE_TYPE): NDArray = head -= other def *(other: NDArray): NDArray = head * other - def *(other: Float): NDArray = head * other + def *(other: MX_PRIMITIVE_TYPE): NDArray = head * other def unary_-(): NDArray = -head def *=(other: NDArray): NDArray = head *= other - def *=(other: Float): NDArray = head *= other + def *=(other: MX_PRIMITIVE_TYPE): NDArray = head *= other def /(other: NDArray): NDArray = head / other + def /(other: MX_PRIMITIVE_TYPE): NDArray = head / other def **(other: NDArray): NDArray = head ** other - def **(other: Float): NDArray = head ** other + def **(other: MX_PRIMITIVE_TYPE): NDArray = head ** other def >(other: NDArray): NDArray = head > other - def >(other: Float): NDArray = head > other + def >(other: MX_PRIMITIVE_TYPE): NDArray = head > other def >=(other: NDArray): NDArray = head >= other - def >=(other: Float): NDArray = head >= other + def >=(other: MX_PRIMITIVE_TYPE): NDArray = head >= other def <(other: NDArray): NDArray = head < other - def <(other: Float): NDArray = head < other + def <(other: MX_PRIMITIVE_TYPE): NDArray = head < other def <=(other: NDArray): NDArray = head <= other - def <=(other: Float): NDArray = head <= other + def <=(other: MX_PRIMITIVE_TYPE): NDArray = head <= other def toArray: Array[Float] = head.toArray + def toFloat64Array: Array[Double] = head.toFloat64Array def toScalar: Float = head.toScalar + def toFloat64Scalar: Double = head.toFloat64Scalar def copyTo(other: NDArray): NDArray = head.copyTo(other) def copyTo(ctx: Context): NDArray = head.copyTo(ctx) def copy(): NDArray = head.copy() diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index a84bd106b763..e30098c3088b 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -53,9 +53,9 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, val label = currentBatch.label(0) // properties val res = ( - // TODO: need to allow user to specify DType and Layout - IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, Layout.UNDEFINED)), - IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, Layout.UNDEFINED)), + // TODO: need to allow user to specify Layout + IndexedSeq(new DataDesc(dataName, data.shape, data.dtype, Layout.UNDEFINED)), + IndexedSeq(new DataDesc(labelName, label.shape, label.dtype, Layout.UNDEFINED)), ListMap(dataName -> data.shape), ListMap(labelName -> label.shape), data.shape(0)) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala index 0032a54dd802..e690abba0d13 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala @@ -61,7 +61,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)], dataBatchSize: Int = 1, shuffle: Boolean = false, lastBatchHandle: String = "pad", dataName: String = "data", labelName: String = "label") { - this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE, Layout.UNDEFINED), + this(IO.initDataDesc(data, allowEmpty = false, dataName, + if (data == null || data.isEmpty) MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED), IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED), dataBatchSize, shuffle, lastBatchHandle) } @@ -272,7 +273,7 @@ object NDArrayIter { */ def addData(name: String, data: NDArray): Builder = { this.data = this.data ++ IndexedSeq((new DataDesc(name, - data.shape, DType.Float32, Layout.UNDEFINED), data)) + data.shape, data.dtype, Layout.UNDEFINED), data)) this } @@ -284,7 +285,7 @@ object NDArrayIter { */ def addLabel(name: String, label: NDArray): Builder = { this.label = this.label ++ IndexedSeq((new DataDesc(name, - label.shape, DType.Float32, Layout.UNDEFINED), label)) + label.shape, label.dtype, Layout.UNDEFINED), label)) this } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala index 198102d2377f..67809c158aff 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala @@ -91,17 +91,26 @@ object NDArray extends NDArrayBase { def full(shape: Shape, value: Float, ctx: Context): NDArray = org.apache.mxnet.NDArray.full(shape, value, ctx) + def full(shape: Shape, value: Double, ctx: Context): NDArray + = org.apache.mxnet.NDArray.full(shape, value, ctx) + def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) + def power(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) + def power(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs) def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) + def maximum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) + def maximum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs) def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) + def minimum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) + def minimum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs) /** @@ -111,6 +120,7 @@ object NDArray extends NDArrayBase { */ def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs) def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs) + def equal(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs) /** * Returns the result of element-wise **not equal to** (!=) comparison operation @@ -120,6 +130,7 @@ object NDArray extends NDArrayBase { */ def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs) def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs) + def notEqual(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs) /** * Returns the result of element-wise **greater than** (>) comparison operation @@ -129,6 +140,7 @@ object NDArray extends NDArrayBase { */ def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs) def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs) + def greater(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs) /** * Returns the result of element-wise **greater than or equal to** (>=) comparison @@ -140,6 +152,8 @@ object NDArray extends NDArrayBase { = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) def greaterEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) + def greaterEqual(lhs: NDArray, rhs: Double): NDArray + = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs) /** * Returns the result of element-wise **lesser than** (<) comparison operation @@ -149,6 +163,7 @@ object NDArray extends NDArrayBase { */ def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs) def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs) + def lesser(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs) /** * Returns the result of element-wise **lesser than or equal to** (<=) comparison @@ -160,6 +175,8 @@ object NDArray extends NDArrayBase { = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) def lesserEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) + def lesserEqual(lhs: NDArray, rhs: Double): NDArray + = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs) /** * Create a new NDArray that copies content from source_array. @@ -172,6 +189,18 @@ object NDArray extends NDArrayBase { = org.apache.mxnet.NDArray.array( sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx) + /** + * Create a new NDArray that copies content from source_array. + * @param sourceArr Source data (list of Doubles) to create NDArray from. + * @param shape shape of the NDArray + * @param ctx The context of the NDArray, default to current default context. + * @return The created NDArray. + */ + def arrayWithDouble(sourceArr: java.util.List[java.lang.Double], shape: Shape, + ctx: Context = null): NDArray + = org.apache.mxnet.NDArray.array( + sourceArr.asScala.map(ele => Double.unbox(ele)).toArray, shape) + /** * Returns evenly spaced values within a given interval. * Values are generated within the half-open interval [`start`, `stop`). In other @@ -205,6 +234,10 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) { this(org.apache.mxnet.NDArray.array(arr, shape, ctx)) } + def this(arr: Array[Double], shape: Shape, ctx: Context) = { + this(org.apache.mxnet.NDArray.array(arr, shape, ctx)) + } + def this(arr: java.util.List[java.lang.Float], shape: Shape, ctx: Context) = { this(NDArray.array(arr, shape, ctx)) } @@ -304,41 +337,59 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) { * @return Current NDArray */ def set(value: Float): NDArray = nd.set(value) + def set(value: Double): NDArray = nd.set(value) def set(other: NDArray): NDArray = nd.set(other) def set(other: Array[Float]): NDArray = nd.set(other) + def set(other: Array[Double]): NDArray = nd.set(other) def add(other: NDArray): NDArray = this.nd + other.nd def add(other: Float): NDArray = this.nd + other + def add(other: Double): NDArray = this.nd + other def addInplace(other: NDArray): NDArray = this.nd += other def addInplace(other: Float): NDArray = this.nd += other + def addInplace(other: Double): NDArray = this.nd += other def subtract(other: NDArray): NDArray = this.nd - other def subtract(other: Float): NDArray = this.nd - other + def subtract(other: Double): NDArray = this.nd - other def subtractInplace(other: NDArray): NDArray = this.nd -= other def subtractInplace(other: Float): NDArray = this.nd -= other + def subtractInplace(other: Double): NDArray = this.nd -= other def multiply(other: NDArray): NDArray = this.nd * other def multiply(other: Float): NDArray = this.nd * other + def multiply(other: Double): NDArray = this.nd * other def multiplyInplace(other: NDArray): NDArray = this.nd *= other def multiplyInplace(other: Float): NDArray = this.nd *= other + def multiplyInplace(other: Double): NDArray = this.nd *= other def div(other: NDArray): NDArray = this.nd / other def div(other: Float): NDArray = this.nd / other + def div(other: Double): NDArray = this.nd / other def divInplace(other: NDArray): NDArray = this.nd /= other def divInplace(other: Float): NDArray = this.nd /= other + def divInplace(other: Double): NDArray = this.nd /= other def pow(other: NDArray): NDArray = this.nd ** other def pow(other: Float): NDArray = this.nd ** other + def pow(other: Double): NDArray = this.nd ** other def powInplace(other: NDArray): NDArray = this.nd **= other def powInplace(other: Float): NDArray = this.nd **= other + def powInplace(other: Double): NDArray = this.nd **= other def mod(other: NDArray): NDArray = this.nd % other def mod(other: Float): NDArray = this.nd % other + def mod(other: Double): NDArray = this.nd % other def modInplace(other: NDArray): NDArray = this.nd %= other def modInplace(other: Float): NDArray = this.nd %= other + def modInplace(other: Double): NDArray = this.nd %= other def greater(other: NDArray): NDArray = this.nd > other def greater(other: Float): NDArray = this.nd > other + def greater(other: Double): NDArray = this.nd > other def greaterEqual(other: NDArray): NDArray = this.nd >= other def greaterEqual(other: Float): NDArray = this.nd >= other + def greaterEqual(other: Double): NDArray = this.nd >= other def lesser(other: NDArray): NDArray = this.nd < other def lesser(other: Float): NDArray = this.nd < other + def lesser(other: Double): NDArray = this.nd < other def lesserEqual(other: NDArray): NDArray = this.nd <= other def lesserEqual(other: Float): NDArray = this.nd <= other + def lesserEqual(other: Double): NDArray = this.nd <= other /** * Return a copied flat java array of current array (row-major). @@ -346,6 +397,12 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) { */ def toArray: Array[Float] = nd.toArray + /** + * Return a copied flat java array of current array (row-major). + * @return A copy of array content. + */ + def toFloat64Array: Array[Double] = nd.toFloat64Array + /** * Return a CPU scalar(float) of current ndarray. * This ndarray must have shape (1,) @@ -354,6 +411,14 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) { */ def toScalar: Float = nd.toScalar + /** + * Return a CPU scalar(float) of current ndarray. + * This ndarray must have shape (1,) + * + * @return The scalar representation of the ndarray. + */ + def toFloat64Scalar: Double = nd.toFloat64Scalar + /** * Copy the content of current array to other. * diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java index 2659b7848bc6..86c7eb29d2ef 100644 --- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java +++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java @@ -40,6 +40,15 @@ public void testCreateNDArray() { new Shape(new int[]{1, 3}), new Context("cpu", 0)); assertTrue(Arrays.equals(nd.shape().toArray(), arr)); + + List list2 = Arrays.asList(1d, 1d, 1d); + nd = NDArray.arrayWithDouble(list2, + new Shape(new int[]{1, 3}), + new Context("cpu", 0)); + + // Float64 assertion + assertTrue(nd.dtype() == DType.Float64()); + } @Test @@ -64,6 +73,12 @@ public void testComparison(){ nd = nd.subtract(nd2); float[] lesser = new float[]{0, 0, 0}; assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser)); + + NDArray nd3 = new NDArray(new double[]{1.0, 2.0, 3.0}, new Shape(new int[]{3}), new Context("cpu", 0)); + nd3 = nd3.add(1.0); + double[] smaller = new double[] {2, 3, 4}; + assertTrue(Arrays.equals(smaller, nd3.toFloat64Array())); + } @Test diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala index 2ec6f668dbcc..d3969b0ce77d 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala @@ -303,5 +303,32 @@ class IOSuite extends FunSuite with BeforeAndAfterAll { assert(dataDesc(0).layout == Layout.NTC) assert(labelDesc(0).dtype == DType.Int32) assert(labelDesc(0).layout == Layout.NT) + + + // Test with passing Float64 hardcoded as Dtype of data + val dataIter4 = new NDArrayIter( + IO.initDataDesc(data, false, "data", DType.Float64, Layout.NTC), + IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT), + 128, false, "pad") + val dataDesc4 = dataIter4.provideDataDesc + val labelDesc4 = dataIter4.provideLabelDesc + assert(dataDesc4(0).dtype == DType.Float64) + assert(dataDesc4(0).layout == Layout.NTC) + assert(labelDesc4(0).dtype == DType.Int32) + assert(labelDesc4(0).layout == Layout.NT) + + // Test with Float64 coming from the data itself + val dataF64 = IndexedSeq(NDArray.ones(shape0, dtype = DType.Float64), + NDArray.zeros(shape0, dtype = DType.Float64)) + + val dataIter5 = new NDArrayIter( + IO.initDataDesc(dataF64, false, "data", DType.Float64, Layout.NTC), + IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT), + 128, false, "pad") + val dataDesc5 = dataIter5.provideDataDesc + assert(dataDesc5(0).dtype == DType.Float64) + assert(dataDesc5(0).dtype != DType.Float32) + assert(dataDesc5(0).layout == Layout.NTC) + } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 2f3b1676d272..bc7a0a026bc3 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.atomic.AtomicInteger import org.apache.mxnet.NDArrayConversions._ -import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { private val sequence: AtomicInteger = new AtomicInteger(0) @@ -29,6 +29,9 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { test("to java array") { val ndarray = NDArray.zeros(2, 2) assert(ndarray.toArray === Array(0f, 0f, 0f, 0f)) + + val float64Array = NDArray.zeros(Shape(2, 2), dtype = DType.Float64) + assert(float64Array.toFloat64Array === Array(0d, 0d, 0d, 0d)) } test("to scalar") { @@ -38,8 +41,17 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(ndones.toScalar === 1f) } + test("to float 64 scalar") { + val ndzeros = NDArray.zeros(Shape(1), dtype = DType.Float64) + assert(ndzeros.toFloat64Scalar === 0d) + val ndones = NDArray.ones(Shape(1), dtype = DType.Float64) + assert(ndones.toFloat64Scalar === 1d) + } + test ("call toScalar on an ndarray which is not a scalar") { intercept[Exception] { NDArray.zeros(1, 1).toScalar } + intercept[Exception] { NDArray.zeros(shape = Shape (1, 1), + dtype = DType.Float64).toFloat64Scalar } } test("size and shape") { @@ -51,12 +63,20 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { test("dtype") { val arr = NDArray.zeros(3, 2) assert(arr.dtype === DType.Float32) + + val float64Array = NDArray.zeros(shape = Shape(3, 2), dtype = DType.Float64) + assert(float64Array.dtype === DType.Float64) } test("set scalar value") { val ndarray = NDArray.empty(2, 1) ndarray.set(10f) assert(ndarray.toArray === Array(10f, 10f)) + + val float64array = NDArray.empty(shape = Shape(2, 1), dtype = DType.Float64) + float64array.set(10d) + assert(float64array.toFloat64Array === Array(10d, 10d)) + } test("copy from java array") { @@ -66,19 +86,29 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("plus") { - val ndzeros = NDArray.zeros(2, 1) - val ndones = ndzeros + 1f + var ndzeros = NDArray.zeros(2, 1) + var ndones = ndzeros + 1f assert(ndones.toArray === Array(1f, 1f)) assert((ndones + ndzeros).toArray === Array(1f, 1f)) assert((1 + ndones).toArray === Array(2f, 2f)) // in-place ndones += ndones assert(ndones.toArray === Array(2f, 2f)) + + // Float64 method test + ndzeros = NDArray.zeros(shape = Shape(2, 1), dtype = DType.Float64) + ndones = ndzeros + 1d + assert(ndones.toFloat64Array === Array(1d, 1d)) + assert((ndones + ndzeros).toFloat64Array === Array(1d, 1d)) + assert((1d + ndones).toArray === Array(2d, 2d)) + // in-place + ndones += ndones + assert(ndones.toFloat64Array === Array(2d, 2d)) } test("minus") { - val ndones = NDArray.ones(2, 1) - val ndzeros = ndones - 1f + var ndones = NDArray.ones(2, 1) + var ndzeros = ndones - 1f assert(ndzeros.toArray === Array(0f, 0f)) assert((ndones - ndzeros).toArray === Array(1f, 1f)) assert((ndzeros - ndones).toArray === Array(-1f, -1f)) @@ -86,23 +116,46 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { // in-place ndones -= ndones assert(ndones.toArray === Array(0f, 0f)) + + // Float64 methods test + ndones = NDArray.ones(shape = Shape(2, 1)) + ndzeros = ndones - 1d + assert(ndzeros.toFloat64Array === Array(0d, 0d)) + assert((ndones - ndzeros).toFloat64Array === Array(1d , 1d)) + assert((ndzeros - ndones).toFloat64Array === Array(-1d , -1d)) + assert((ndones - 1).toFloat64Array === Array(0d, 0d)) + // in-place + ndones -= ndones + assert(ndones.toArray === Array(0d, 0d)) + } test("multiplication") { - val ndones = NDArray.ones(2, 1) - val ndtwos = ndones * 2 + var ndones = NDArray.ones(2, 1) + var ndtwos = ndones * 2 assert(ndtwos.toArray === Array(2f, 2f)) assert((ndones * ndones).toArray === Array(1f, 1f)) assert((ndtwos * ndtwos).toArray === Array(4f, 4f)) ndtwos *= ndtwos // in-place assert(ndtwos.toArray === Array(4f, 4f)) + + // Float64 methods test + ndones = NDArray.ones(shape = Shape(2, 1), dtype = DType.Float64) + ndtwos = ndones * 2d + assert(ndtwos.toFloat64Array === Array(2d, 2d)) + assert((ndones * ndones).toFloat64Array === Array(1d, 1d)) + assert((ndtwos * ndtwos).toFloat64Array === Array(4d, 4d)) + ndtwos *= ndtwos + // in-place + assert(ndtwos.toFloat64Array === Array(4d, 4d)) + } test("division") { - val ndones = NDArray.ones(2, 1) - val ndzeros = ndones - 1f - val ndhalves = ndones / 2 + var ndones = NDArray.ones(2, 1) + var ndzeros = ndones - 1f + var ndhalves = ndones / 2 assert(ndhalves.toArray === Array(0.5f, 0.5f)) assert((ndhalves / ndhalves).toArray === Array(1f, 1f)) assert((ndones / ndones).toArray === Array(1f, 1f)) @@ -110,37 +163,75 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { ndhalves /= ndhalves // in-place assert(ndhalves.toArray === Array(1f, 1f)) + + // Float64 methods test + ndones = NDArray.ones(shape = Shape (2, 1), dtype = DType.Float64) + ndzeros = ndones - 1d + ndhalves = ndones / 2d + assert(ndhalves.toFloat64Array === Array(0.5d, 0.5d)) + assert((ndhalves / ndhalves).toFloat64Array === Array(1d, 1d)) + assert((ndones / ndones).toFloat64Array === Array(1d, 1d)) + assert((ndzeros / ndones).toFloat64Array === Array(0d, 0d)) + ndhalves /= ndhalves + // in-place + assert(ndhalves.toFloat64Array === Array(1d, 1d)) } test("full") { - val arr = NDArray.full(Shape(1, 2), 3f) + var arr = NDArray.full(Shape(1, 2), 3f) assert(arr.shape === Shape(1, 2)) assert(arr.toArray === Array(3f, 3f)) + + // Float64 methods test + arr = NDArray.full(Shape(1, 2), value = 5d, Context.cpu()) + assert(arr.toFloat64Array === Array (5d, 5d)) } test("clip") { - val ndarray = NDArray.empty(3, 2) + var ndarray = NDArray.empty(3, 2) ndarray.set(Array(1f, 2f, 3f, 4f, 5f, 6f)) assert(NDArray.clip(ndarray, 2f, 5f).toArray === Array(2f, 2f, 3f, 4f, 5f, 5f)) + + // Float64 methods test + ndarray = NDArray.empty(shape = Shape(3, 2), dtype = DType.Float64) + ndarray.set(Array(1d, 2d, 3d, 4d, 5d, 6d)) + assert(NDArray.clip(ndarray, 2d, 5d).toFloat64Array === Array(2d, 2d, 3d, 4d, 5d, 5d)) } test("sqrt") { - val ndarray = NDArray.empty(4, 1) + var ndarray = NDArray.empty(4, 1) ndarray.set(Array(0f, 1f, 4f, 9f)) assert(NDArray.sqrt(ndarray).toArray === Array(0f, 1f, 2f, 3f)) + + // Float64 methods test + ndarray = NDArray.empty(shape = Shape(4, 1), dtype = DType.Float64) + ndarray.set(Array(0d, 1d, 4d, 9d)) + assert(NDArray.sqrt(ndarray).toFloat64Array === Array(0d, 1d, 2d, 3d)) } test("rsqrt") { - val ndarray = NDArray.array(Array(1f, 4f), shape = Shape(2, 1)) + var ndarray = NDArray.array(Array(1f, 4f), shape = Shape(2, 1)) assert(NDArray.rsqrt(ndarray).toArray === Array(1f, 0.5f)) + + // Float64 methods test + ndarray = NDArray.array(Array(1d, 4d, 25d), shape = Shape(3, 1), Context.cpu()) + assert(NDArray.rsqrt(ndarray).toFloat64Array === Array(1d, 0.5d, 0.2d)) } test("norm") { - val ndarray = NDArray.empty(3, 1) + var ndarray = NDArray.empty(3, 1) ndarray.set(Array(1f, 2f, 3f)) - val normed = NDArray.norm(ndarray) + var normed = NDArray.norm(ndarray) assert(normed.shape === Shape(1)) assert(normed.toScalar === math.sqrt(14.0).toFloat +- 1e-3f) + + // Float64 methods test + ndarray = NDArray.empty(shape = Shape(3, 1), dtype = DType.Float64) + ndarray.set(Array(1d, 2d, 3d)) + normed = NDArray.norm(ndarray) + assert(normed.get.dtype === DType.Float64) + assert(normed.shape === Shape(1)) + assert(normed.toFloat64Scalar === math.sqrt(14.0) +- 1e-3d) } test("one hot encode") { @@ -176,25 +267,26 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("power") { - val arr = NDArray.array(Array(3f, 5f), shape = Shape(2, 1)) + var arr = NDArray.array(Array(3f, 5f), shape = Shape(2, 1)) - val arrPower1 = NDArray.power(2f, arr) + var arrPower1 = NDArray.power(2f, arr) assert(arrPower1.shape === Shape(2, 1)) assert(arrPower1.toArray === Array(8f, 32f)) - val arrPower2 = NDArray.power(arr, 2f) + var arrPower2 = NDArray.power(arr, 2f) assert(arrPower2.shape === Shape(2, 1)) assert(arrPower2.toArray === Array(9f, 25f)) - val arrPower3 = NDArray.power(arr, arr) + var arrPower3 = NDArray.power(arr, arr) assert(arrPower3.shape === Shape(2, 1)) assert(arrPower3.toArray === Array(27f, 3125f)) - val arrPower4 = arr ** 2f + var arrPower4 = arr ** 2f + assert(arrPower4.shape === Shape(2, 1)) assert(arrPower4.toArray === Array(9f, 25f)) - val arrPower5 = arr ** arr + var arrPower5 = arr ** arr assert(arrPower5.shape === Shape(2, 1)) assert(arrPower5.toArray === Array(27f, 3125f)) @@ -206,84 +298,211 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { arr **= arr assert(arr.shape === Shape(2, 1)) assert(arr.toArray === Array(27f, 3125f)) + + // Float64 tests + arr = NDArray.array(Array(3d, 5d), shape = Shape(2, 1)) + + arrPower1 = NDArray.power(2d, arr) + assert(arrPower1.shape === Shape(2, 1)) + assert(arrPower1.dtype === DType.Float64) + assert(arrPower1.toFloat64Array === Array(8d, 32d)) + + arrPower2 = NDArray.power(arr, 2d) + assert(arrPower2.shape === Shape(2, 1)) + assert(arrPower2.dtype === DType.Float64) + assert(arrPower2.toFloat64Array === Array(9d, 25d)) + + arrPower3 = NDArray.power(arr, arr) + assert(arrPower3.shape === Shape(2, 1)) + assert(arrPower3.dtype === DType.Float64) + assert(arrPower3.toFloat64Array === Array(27d, 3125d)) + + arrPower4 = arr ** 2f + assert(arrPower4.shape === Shape(2, 1)) + assert(arrPower4.dtype === DType.Float64) + assert(arrPower4.toFloat64Array === Array(9d, 25d)) + + arrPower5 = arr ** arr + assert(arrPower5.shape === Shape(2, 1)) + assert(arrPower5.dtype === DType.Float64) + assert(arrPower5.toFloat64Array === Array(27d, 3125d)) + + arr **= 2d + assert(arr.shape === Shape(2, 1)) + assert(arr.dtype === DType.Float64) + assert(arr.toFloat64Array === Array(9d, 25d)) + + arr.set(Array(3d, 5d)) + arr **= arr + assert(arr.shape === Shape(2, 1)) + assert(arr.dtype === DType.Float64) + assert(arr.toFloat64Array === Array(27d, 3125d)) } test("equal") { - val arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2)) - val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) + var arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2)) + var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) - val arrEqual1 = NDArray.equal(arr1, arr2) + var arrEqual1 = NDArray.equal(arr1, arr2) assert(arrEqual1.shape === Shape(2, 2)) assert(arrEqual1.toArray === Array(1f, 0f, 1f, 0f)) - val arrEqual2 = NDArray.equal(arr1, 3f) + var arrEqual2 = NDArray.equal(arr1, 3f) assert(arrEqual2.shape === Shape(2, 2)) assert(arrEqual2.toArray === Array(0f, 0f, 1f, 0f)) + + + // Float64 methods test + arr1 = NDArray.array(Array(1d, 2d, 3d, 5d), shape = Shape(2, 2)) + arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2)) + + arrEqual1 = NDArray.equal(arr1, arr2) + assert(arrEqual1.shape === Shape(2, 2)) + assert(arrEqual1.dtype === DType.Float64) + assert(arrEqual1.toFloat64Array === Array(1d, 0d, 1d, 0d)) + + arrEqual2 = NDArray.equal(arr1, 3d) + assert(arrEqual2.shape === Shape(2, 2)) + assert(arrEqual2.dtype === DType.Float64) + assert(arrEqual2.toFloat64Array === Array(0d, 0d, 1d, 0d)) } test("not_equal") { - val arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2)) - val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) + var arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2)) + var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) - val arrEqual1 = NDArray.notEqual(arr1, arr2) + var arrEqual1 = NDArray.notEqual(arr1, arr2) assert(arrEqual1.shape === Shape(2, 2)) assert(arrEqual1.toArray === Array(0f, 1f, 0f, 1f)) - val arrEqual2 = NDArray.notEqual(arr1, 3f) + var arrEqual2 = NDArray.notEqual(arr1, 3f) assert(arrEqual2.shape === Shape(2, 2)) assert(arrEqual2.toArray === Array(1f, 1f, 0f, 1f)) + + // Float64 methods test + + arr1 = NDArray.array(Array(1d, 2d, 3d, 5d), shape = Shape(2, 2)) + arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2)) + + arrEqual1 = NDArray.notEqual(arr1, arr2) + assert(arrEqual1.shape === Shape(2, 2)) + assert(arrEqual1.dtype === DType.Float64) + assert(arrEqual1.toFloat64Array === Array(0d, 1d, 0d, 1d)) + + arrEqual2 = NDArray.notEqual(arr1, 3d) + assert(arrEqual2.shape === Shape(2, 2)) + assert(arrEqual2.dtype === DType.Float64) + assert(arrEqual2.toFloat64Array === Array(1d, 1d, 0d, 1d)) + } test("greater") { - val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2)) - val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) + var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2)) + var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) - val arrEqual1 = arr1 > arr2 + var arrEqual1 = arr1 > arr2 assert(arrEqual1.shape === Shape(2, 2)) assert(arrEqual1.toArray === Array(0f, 0f, 1f, 0f)) - val arrEqual2 = arr1 > 2f + var arrEqual2 = arr1 > 2f assert(arrEqual2.shape === Shape(2, 2)) assert(arrEqual2.toArray === Array(0f, 0f, 1f, 1f)) + + // Float64 methods test + arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2)) + arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2)) + + arrEqual1 = arr1 > arr2 + assert(arrEqual1.shape === Shape(2, 2)) + assert(arrEqual1.dtype === DType.Float64) + assert(arrEqual1.toFloat64Array === Array(0d, 0d, 1d, 0d)) + + arrEqual2 = arr1 > 2d + assert(arrEqual2.shape === Shape(2, 2)) + assert(arrEqual2.dtype === DType.Float64) + assert(arrEqual2.toFloat64Array === Array(0d, 0d, 1d, 1d)) } test("greater_equal") { - val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2)) - val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) + var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2)) + var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) - val arrEqual1 = arr1 >= arr2 + var arrEqual1 = arr1 >= arr2 assert(arrEqual1.shape === Shape(2, 2)) assert(arrEqual1.toArray === Array(1f, 0f, 1f, 0f)) - val arrEqual2 = arr1 >= 2f + var arrEqual2 = arr1 >= 2f assert(arrEqual2.shape === Shape(2, 2)) assert(arrEqual2.toArray === Array(0f, 1f, 1f, 1f)) + + // Float64 methods test + arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2)) + arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2)) + + arrEqual1 = arr1 >= arr2 + assert(arrEqual1.shape === Shape(2, 2)) + assert(arrEqual1.dtype === DType.Float64) + assert(arrEqual1.toFloat64Array === Array(1d, 0d, 1d, 0d)) + + arrEqual2 = arr1 >= 2d + assert(arrEqual2.shape === Shape(2, 2)) + assert(arrEqual2.dtype === DType.Float64) + assert(arrEqual2.toFloat64Array === Array(0d, 1d, 1d, 1d)) } test("lesser") { - val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2)) - val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) + var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2)) + var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) - val arrEqual1 = arr1 < arr2 + var arrEqual1 = arr1 < arr2 assert(arrEqual1.shape === Shape(2, 2)) assert(arrEqual1.toArray === Array(0f, 1f, 0f, 1f)) - val arrEqual2 = arr1 < 2f + var arrEqual2 = arr1 < 2f assert(arrEqual2.shape === Shape(2, 2)) assert(arrEqual2.toArray === Array(1f, 0f, 0f, 0f)) + + // Float64 methods test + arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2)) + arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2)) + + arrEqual1 = arr1 < arr2 + assert(arrEqual1.shape === Shape(2, 2)) + assert(arrEqual1.dtype === DType.Float64) + assert(arrEqual1.toFloat64Array === Array(0d, 1d, 0d, 1d)) + + arrEqual2 = arr1 < 2d + assert(arrEqual2.shape === Shape(2, 2)) + assert(arrEqual2.dtype === DType.Float64) + assert(arrEqual2.toFloat64Array === Array(1d, 0d, 0d, 0d)) + } test("lesser_equal") { - val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2)) - val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) + var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2)) + var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2)) - val arrEqual1 = arr1 <= arr2 + var arrEqual1 = arr1 <= arr2 assert(arrEqual1.shape === Shape(2, 2)) assert(arrEqual1.toArray === Array(1f, 1f, 0f, 1f)) - val arrEqual2 = arr1 <= 2f + var arrEqual2 = arr1 <= 2f assert(arrEqual2.shape === Shape(2, 2)) assert(arrEqual2.toArray === Array(1f, 1f, 0f, 0f)) + + // Float64 methods test + arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2)) + arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2)) + + arrEqual1 = arr1 <= arr2 + assert(arrEqual1.shape === Shape(2, 2)) + assert(arrEqual1.dtype === DType.Float64) + assert(arrEqual1.toFloat64Array === Array(1d, 1d, 0d, 1d)) + + arrEqual2 = arr1 <= 2d + assert(arrEqual2.shape === Shape(2, 2)) + assert(arrEqual2.dtype === DType.Float64) + assert(arrEqual2.toFloat64Array === Array(1d, 1d, 0d, 0d)) } test("choose_element_0index") { @@ -294,11 +513,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { } test("copy to") { - val source = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3)) - val dest = NDArray.empty(1, 3) + var source = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3)) + var dest = NDArray.empty(1, 3) source.copyTo(dest) assert(dest.shape === Shape(1, 3)) assert(dest.toArray === Array(1f, 2f, 3f)) + + // Float64 methods test + source = NDArray.array(Array(1d, 2d, 3d), shape = Shape(1, 3)) + dest = NDArray.empty(shape = Shape(1, 3), dtype = DType.Float64) + source.copyTo(dest) + assert(dest.dtype === DType.Float64) + assert(dest.toFloat64Array === Array(1d, 2d, 3d)) } test("abs") { @@ -365,6 +591,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val arr = NDArray.maximum(arr1, arr2) assert(arr.shape === Shape(3, 1)) assert(arr.toArray === Array(4f, 2.1f, 3.7f)) + + // Float64 methods test + val arr3 = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1)) + val maxArr = NDArray.maximum(arr3, 10d) + assert(maxArr.shape === Shape(3, 1)) + assert(maxArr.toArray === Array(10d, 10d, 10d)) } test("min") { @@ -378,11 +610,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val arr = NDArray.minimum(arr1, arr2) assert(arr.shape === Shape(3, 1)) assert(arr.toArray === Array(1.5f, 1f, 3.5f)) + + // Float64 methods test + val arr3 = NDArray.array(Array(4d, 5d, 6d), shape = Shape(3, 1)) + val minArr = NDArray.minimum(arr3, 5d) + assert(minArr.shape === Shape(3, 1)) + assert(minArr.toFloat64Array === Array(4d, 5d, 5d)) } test("sum") { - val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Shape(2, 2)) + var arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Shape(2, 2)) assert(NDArray.sum(arr).toScalar === 10f +- 1e-3f) + } test("argmaxChannel") { @@ -398,6 +637,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val arr = NDArray.concatenate(arr1, arr2) assert(arr.shape === Shape(3, 3)) assert(arr.toArray === Array(1f, 2f, 4f, 3f, 3f, 3f, 8f, 7f, 6f)) + + // Try concatenating float32 arr with float64 arr. Should get exception + intercept[Exception] { + val arr3 = NDArray.array(Array (5d, 6d, 7d), shape = Shape(1, 3)) + NDArray.concatenate(Array(arr1, arr3)) + } } test("concatenate axis-1") { @@ -406,6 +651,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val arr = NDArray.concatenate(Array(arr1, arr2), axis = 1) assert(arr.shape === Shape(2, 3)) assert(arr.toArray === Array(1f, 2f, 5f, 3f, 4f, 6f)) + + // Try concatenating float32 arr with float64 arr. Should get exception + intercept[Exception] { + val arr3 = NDArray.array(Array (5d, 6d), shape = Shape(2, 1)) + NDArray.concatenate(Array(arr1, arr3), axis = 1) + } } test("transpose") { @@ -428,6 +679,24 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val loadedArray = arrays(0) assert(loadedArray.shape === Shape(3, 1)) assert(loadedArray.toArray === Array(1f, 2f, 3f)) + assert(loadedArray.dtype === DType.Float32) + } finally { + val file = new File(filename) + file.delete() + } + + // Try the same for Float64 array + try { + val ndarray = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu()) + NDArray.save(filename, Map("local" -> ndarray)) + val (keys, arrays) = NDArray.load(filename) + assert(keys.length === 1) + assert(keys(0) === "local") + assert(arrays.length === 1) + val loadedArray = arrays(0) + assert(loadedArray.shape === Shape(3, 1)) + assert(loadedArray.toArray === Array(1d, 2d, 3d)) + assert(loadedArray.dtype === DType.Float64) } finally { val file = new File(filename) file.delete() @@ -446,6 +715,24 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val loadedArray = arrays(0) assert(loadedArray.shape === Shape(3, 1)) assert(loadedArray.toArray === Array(1f, 2f, 3f)) + assert(loadedArray.dtype === DType.Float32) + } finally { + val file = new File(filename) + file.delete() + } + + // Try the same thing for Float64 array : + + try { + val ndarray = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu()) + NDArray.save(filename, Array(ndarray)) + val (keys, arrays) = NDArray.load(filename) + assert(keys.length === 0) + assert(arrays.length === 1) + val loadedArray = arrays(0) + assert(loadedArray.shape === Shape(3, 1)) + assert(loadedArray.toArray === Array(1d, 2d, 3d)) + assert(loadedArray.dtype === DType.Float64) } finally { val file = new File(filename) file.delete() @@ -464,9 +751,11 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val ndarray2 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(3, 1)) val ndarray3 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3)) val ndarray4 = NDArray.array(Array(3f, 2f, 3f), shape = Shape(3, 1)) + val ndarray5 = NDArray.array(Array(3d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu()) ndarray1 shouldEqual ndarray2 ndarray1 shouldNot equal(ndarray3) ndarray1 shouldNot equal(ndarray4) + ndarray5 shouldNot equal(ndarray3) } test("slice") { @@ -545,6 +834,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val bytes = arr.serialize() val arrCopy = NDArray.deserialize(bytes) assert(arr === arrCopy) + assert(arrCopy.dtype === DType.Float32) } test("dtype int32") { @@ -580,18 +870,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { test("NDArray random module is generated properly") { val lam = NDArray.ones(1, 2) val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4))) - val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4))) + val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)), + dtype = Some("float64")) assert(rnd.shape === Shape(1, 2, 3, 4)) assert(rnd2.shape === Shape(3, 4)) + assert(rnd2.head.dtype === DType.Float64) } test("NDArray random module is generated properly - special case of 'normal'") { val mu = NDArray.ones(1, 2) val sigma = NDArray.ones(1, 2) * 2 val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4))) - val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4))) + val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4)), + dtype = Some("float64")) assert(rnd.shape === Shape(1, 2, 3, 4)) assert(rnd2.shape === Shape(3, 4)) + assert(rnd2.head.dtype === DType.Float64) } test("Generated api") { diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala index f6c283c3dfb2..9f0430eaada6 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala @@ -19,6 +19,7 @@ package org.apache.mxnetexamples.imclassification import java.util.concurrent._ +import org.apache.mxnet.DType.DType import org.apache.mxnetexamples.imclassification.models._ import org.apache.mxnetexamples.imclassification.util.Trainer import org.apache.mxnet._ @@ -42,12 +43,13 @@ object TrainModel { * @return The final validation accuracy */ def test(model: String, dataPath: String, numExamples: Int = 60000, - numEpochs: Int = 10, benchmark: Boolean = false): Float = { + numEpochs: Int = 10, benchmark: Boolean = false, + dtype: DType = DType.Float32): Float = { ResourceScope.using() { val devs = Array(Context.cpu(0)) val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String] val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath, - numExamples = numExamples, benchmark = benchmark) + numExamples = numExamples, benchmark = benchmark, dtype = dtype) val Acc = Trainer.fit(batchSize = 128, numExamples, devs = devs, network = net, dataLoader = dataLoader, kvStore = "local", numEpochs = numEpochs) @@ -69,7 +71,7 @@ object TrainModel { */ def dataLoaderAndModel(dataset: String, model: String, dataDir: String = "", numLayers: Int = 50, numExamples: Int = 60000, - benchmark: Boolean = false + benchmark: Boolean = false, dtype: DType = DType.Float32 ): ((Int, KVStore) => (DataIter, DataIter), Symbol) = { val (imageShape, numClasses) = dataset match { case "mnist" => (List(1, 28, 28), 10) @@ -80,16 +82,17 @@ object TrainModel { val List(channels, height, width) = imageShape val dataSize: Int = channels * height * width val (datumShape, net) = model match { - case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses)) - case "lenet" => (List(channels, height, width), Lenet.getSymbol(numClasses)) + case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses, dtype = dtype)) + case "lenet" => (List(channels, height, width), Lenet.getSymbol(numClasses, dtype = dtype)) case "resnet" => (List(channels, height, width), Resnet.getSymbol(numClasses, - numLayers, imageShape)) + numLayers, imageShape, dtype = dtype)) case _ => throw new Exception("Invalid model name") } val dataLoader: (Int, KVStore) => (DataIter, DataIter) = if (benchmark) { (batchSize: Int, kv: KVStore) => { - val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples) + val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples, + dtype) (iter, iter) } } else { @@ -116,8 +119,10 @@ object TrainModel { val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME") else inst.dataDir + val dtype = DType.withName(inst.dType) + val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath, - inst.numLayers, inst.numExamples, inst.benchmark) + inst.numLayers, inst.numExamples, inst.benchmark, dtype) val devs = if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt)) @@ -210,5 +215,8 @@ class TrainModel { private val numWorker: Int = 1 @Option(name = "--num-server", usage = "# of servers") private val numServer: Int = 1 + @Option(name = "--dtype", usage = "data type of the model to train. " + + "Can be float32/float64. Works only with synthetic data currently") + private val dType: String = "float32" } diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala index 9421f1021619..e4d3b2ae7c3e 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala @@ -24,7 +24,7 @@ import scala.collection.immutable.ListMap import scala.util.Random class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[Int], - labelShape: List[Int], maxIter: Int, dtype: DType = DType.Float32 + labelShape: List[Int], maxIter: Int, dType: DType = DType.Float32 ) extends DataIter { var curIter = 0 val random = new Random() @@ -35,12 +35,12 @@ class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[In var label: IndexedSeq[NDArray] = IndexedSeq( NDArray.api.random_uniform(Some(0f), Some(maxLabel), shape = Some(batchLabelShape))) var data: IndexedSeq[NDArray] = IndexedSeq( - NDArray.api.random_uniform(shape = Some(shape))) + NDArray.api.random_uniform(shape = Some(shape), dtype = Some(dType.toString))) val provideDataDesc: IndexedSeq[DataDesc] = IndexedSeq( - new DataDesc("data", shape, dtype, Layout.UNDEFINED)) + new DataDesc("data", shape, data(0).dtype, Layout.UNDEFINED)) val provideLabelDesc: IndexedSeq[DataDesc] = IndexedSeq( - new DataDesc("softmax_label", batchLabelShape, dtype, Layout.UNDEFINED)) + new DataDesc("softmax_label", batchLabelShape, label(0).dtype, Layout.UNDEFINED)) val getPad: Int = 0 override def getData(): IndexedSeq[NDArray] = data diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala index 76fb7bb66022..6f8b138d5ccb 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala @@ -17,6 +17,7 @@ package org.apache.mxnetexamples.imclassification.models +import org.apache.mxnet.DType.DType import org.apache.mxnet._ object Lenet { @@ -26,8 +27,8 @@ object Lenet { * @param numClasses Number of classes to classify into * @return model symbol */ - def getSymbol(numClasses: Int): Symbol = { - val data = Symbol.Variable("data") + def getSymbol(numClasses: Int, dtype: DType = DType.Float32): Symbol = { + val data = Symbol.Variable("data", dType = dtype) // first conv val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5), num_filter = 20) val tanh1 = Symbol.api.tanh(data = Some(conv1)) diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala index 5d880bbe0619..089b65f24a65 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala @@ -17,6 +17,7 @@ package org.apache.mxnetexamples.imclassification.models +import org.apache.mxnet.DType.DType import org.apache.mxnet._ object MultiLayerPerceptron { @@ -26,8 +27,8 @@ object MultiLayerPerceptron { * @param numClasses Number of classes to classify into * @return model symbol */ - def getSymbol(numClasses: Int): Symbol = { - val data = Symbol.Variable("data") + def getSymbol(numClasses: Int, dtype: DType = DType.Float32): Symbol = { + val data = Symbol.Variable("data", dType = dtype) val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1") val act1 = Symbol.api.Activation(data = Some(fc1), "relu", name = "relu") diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala index c3f43d97e898..e5f597680f99 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala @@ -17,6 +17,7 @@ package org.apache.mxnetexamples.imclassification.models +import org.apache.mxnet.DType.DType import org.apache.mxnet._ object Resnet { @@ -77,13 +78,14 @@ object Resnet { */ def resnet(units: List[Int], numStages: Int, filterList: List[Int], numClasses: Int, imageShape: List[Int], bottleNeck: Boolean = true, bnMom: Float = 0.9f, - workspace: Int = 256, dtype: String = "float32", memonger: Boolean = false): Symbol = { + workspace: Int = 256, dtype: DType = DType.Float32, + memonger: Boolean = false): Symbol = { assert(units.size == numStages) var data = Symbol.Variable("data", shape = Shape(List(4) ::: imageShape), dType = DType.Float32) - if (dtype == "float32") { + if (dtype == DType.Float32) { data = Symbol.api.identity(Some(data), "id") - } else if (dtype == "float16") { - data = Symbol.api.cast(Some(data), "float16") + } else if (dtype == DType.Float16) { + data = Symbol.api.cast(Some(data), DType.Float16.toString) } data = Symbol.api.BatchNorm(Some(data), fix_gamma = Some(true), eps = Some(2e-5), momentum = Some(bnMom), name = "bn_data") @@ -118,8 +120,8 @@ object Resnet { kernel = Some(Shape(7, 7)), pool_type = Some("avg"), name = "pool1") val flat = Symbol.api.Flatten(Some(pool1)) var fc1 = Symbol.api.FullyConnected(Some(flat), num_hidden = numClasses, name = "fc1") - if (dtype == "float16") { - fc1 = Symbol.api.cast(Some(fc1), "float32") + if (dtype == DType.Float16) { + fc1 = Symbol.api.cast(Some(fc1), DType.Float32.toString) } Symbol.api.SoftmaxOutput(Some(fc1), name = "softmax") } @@ -134,7 +136,7 @@ object Resnet { * @return Model symbol */ def getSymbol(numClasses: Int, numLayers: Int, imageShape: List[Int], convWorkspace: Int = 256, - dtype: String = "float32"): Symbol = { + dtype: DType = DType.Float32): Symbol = { val List(channels, height, width) = imageShape val (numStages, units, filterList, bottleNeck): (Int, List[Int], List[Int], Boolean) = if (height <= 28) { diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala index 6e9667abe9c0..0daba5a97d77 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala @@ -19,7 +19,7 @@ package org.apache.mxnetexamples.imclassification import java.io.File -import org.apache.mxnet.Context +import org.apache.mxnet.{Context, DType} import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory @@ -55,9 +55,15 @@ class IMClassificationExampleSuite extends FunSuite with BeforeAndAfterAll { for(model <- List("mlp", "lenet", "resnet")) { test(s"Example CI: Test Image Classification Model ${model}") { - var context = Context.cpu() val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true) } } + for(model <- List("mlp", "lenet", "resnet")) { + test(s"Example CI: Test Image Classification Model ${model} with Float64 input") { + val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true, + dtype = DType.Float64) + } + } + } diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala index 5208923275f6..bf6581588114 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala @@ -17,9 +17,10 @@ package org.apache.mxnet.infer -import org.apache.mxnet.{Context, DataDesc, NDArray} +import org.apache.mxnet._ import java.io.File +import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE import org.slf4j.LoggerFactory import scala.io @@ -30,13 +31,13 @@ trait ClassifierBase { /** * Takes an array of floats and returns corresponding (Label, Score) tuples - * @param input Indexed sequence one-dimensional array of floats + * @param input Indexed sequence one-dimensional array of floats/doubles * @param topK (Optional) How many result (sorting based on the last axis) * elements to return. Default returns unsorted output. * @return Indexed sequence of (Label, Score) tuples */ - def classify(input: IndexedSeq[Array[Float]], - topK: Option[Int] = None): IndexedSeq[(String, Float)] + def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]], + topK: Option[Int] = None): IndexedSeq[(String, T)] /** * Takes a sequence of NDArrays and returns (Label, Score) tuples @@ -78,17 +79,35 @@ class Classifier(modelPathPrefix: String, /** * Takes flat arrays as input and returns (Label, Score) tuples. - * @param input Indexed sequence one-dimensional array of floats + * @param input Indexed sequence one-dimensional array of floats/doubles * @param topK (Optional) How many result (sorting based on the last axis) * elements to return. Default returns unsorted output. * @return Indexed sequence of (Label, Score) tuples */ - override def classify(input: IndexedSeq[Array[Float]], - topK: Option[Int] = None): IndexedSeq[(String, Float)] = { + override def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]], + topK: Option[Int] = None): IndexedSeq[(String, T)] = { + + // considering only the first output + val result = input(0)(0) match { + case d: Double => { + classifyImpl(input.asInstanceOf[IndexedSeq[Array[Double]]], topK) + } + case _ => { + classifyImpl(input.asInstanceOf[IndexedSeq[Array[Float]]], topK) + } + } + + result.asInstanceOf[IndexedSeq[(String, T)]] + } + + private def classifyImpl[B, A <: MX_PRIMITIVE_TYPE] + (input: IndexedSeq[Array[B]], topK: Option[Int] = None)(implicit ev: B => A) + : IndexedSeq[(String, B)] = { // considering only the first output val predictResult = predictor.predict(input)(0) - var result: IndexedSeq[(String, Float)] = IndexedSeq.empty + + var result: IndexedSeq[(String, B)] = IndexedSeq.empty if (topK.isDefined) { val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get) @@ -105,7 +124,7 @@ class Classifier(modelPathPrefix: String, * @param input Indexed sequence of NDArrays * @param topK (Optional) How many result (sorting based on the last axis) * elements to return. Default returns unsorted output. - * @return Traversable sequence of (Label, Score) tuples + * @return Traversable sequence of (Label, Score) tuples. */ override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None) : IndexedSeq[IndexedSeq[(String, Float)]] = { @@ -113,7 +132,7 @@ class Classifier(modelPathPrefix: String, // considering only the first output // Copy NDArray to CPU to avoid frequent GPU to CPU copying val predictResultND: NDArray = - predictor.predictWithNDArray(input)(0).asInContext(Context.cpu()) + predictor.predictWithNDArray(input)(0).asInContext(Context.cpu()) // Parallel Execution with ParArray for better performance val predictResultPar: ParArray[Array[Float]] = new ParArray[Array[Float]](predictResultND.shape(0)) diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala index 96be12179d42..3c80f9226399 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala @@ -17,7 +17,8 @@ package org.apache.mxnet.infer -import org.apache.mxnet.{Context, DataDesc, NDArray, Shape} +import org.apache.mxnet.DType.DType +import org.apache.mxnet._ import scala.collection.mutable.ListBuffer @@ -70,14 +71,18 @@ class ImageClassifier(modelPathPrefix: String, * * @param inputImage Path prefix of the input image * @param topK Number of result elements to return, sorted by probability + * @param dType The precision at which to run the inference. + * specify the DType as DType.Float64 for Double precision. + * Defaults to DType.Float32 * @return List of list of tuples of (Label, Probability) */ - def classifyImage(inputImage: BufferedImage, - topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]] = { + def classifyImage + (inputImage: BufferedImage, topK: Option[Int] = None, dType: DType = DType.Float32): + IndexedSeq[IndexedSeq[(String, Float)]] = { val scaledImage = ImageClassifier.reshapeImage(inputImage, width, height) val imageShape = inputShape.drop(1) - val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape) + val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape, dType) val imgWithBatchNum = NDArray.api.expand_dims(pixelsNDArray, 0) inputImage.flush() scaledImage.flush() @@ -95,16 +100,19 @@ class ImageClassifier(modelPathPrefix: String, * * @param inputBatch Input array of buffered images * @param topK Number of result elements to return, sorted by probability + * @param dType The precision at which to run the inference. + * specify the DType as DType.Float64 for Double precision. + * Defaults to DType.Float32 * @return List of list of tuples of (Label, Probability) */ - def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None): - IndexedSeq[IndexedSeq[(String, Float)]] = { + def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None, + dType: DType = DType.Float32): IndexedSeq[IndexedSeq[(String, Float)]] = { val inputBatchSeq = inputBatch.toIndexedSeq val imageBatch = inputBatchSeq.indices.par.map(idx => { val scaledImage = ImageClassifier.reshapeImage(inputBatchSeq(idx), width, height) val imageShape = inputShape.drop(1) - val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape) + val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape, dType) val imgWithBatch = NDArray.api.expand_dims(imgND, 0).get handler.execute(imgND.dispose()) imgWithBatch @@ -152,11 +160,29 @@ object ImageClassifier { * returned by this method after the use. *

* @param resizedImage BufferedImage to get pixels from + * * @param inputImageShape Input shape; for example for resnet it is (3,224,224). Should be same as inputDescriptor shape. + * @param dType The DataType of the NDArray created from the image + * that should be returned. + * Currently it defaults to Dtype.Float32 * @return NDArray pixels array with shape (3, 224, 224) in CHW format */ - def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = { + def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape, + dType : DType = DType.Float32): NDArray = { + + if (dType == DType.Float64) { + val result = getFloatPixelsArray(resizedImage) + NDArray.array(result.map(_.toDouble), shape = inputImageShape) + } + else { + val result = getFloatPixelsArray(resizedImage) + NDArray.array(result, shape = inputImageShape) + } + } + + private def getFloatPixelsArray(resizedImage: BufferedImage): Array[Float] = { + // Get height and width of the image val w = resizedImage.getWidth() val h = resizedImage.getHeight() @@ -166,7 +192,6 @@ object ImageClassifier { // 3 times height and width for R,G,B channels val result = new Array[Float](3 * h * w) - var row = 0 // copy pixels to array vertically while (row < h) { @@ -184,11 +209,10 @@ object ImageClassifier { } row += 1 } + resizedImage.flush() - // creating NDArray according to the input shape - val pixelsArray = NDArray.array(result, shape = inputImageShape) - pixelsArray + result } /** diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala index d4bce9f0d71e..67692a316cc4 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala @@ -17,8 +17,9 @@ package org.apache.mxnet.infer +import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE import org.apache.mxnet.io.NDArrayIter -import org.apache.mxnet.{Context, DataDesc, NDArray, Shape} +import org.apache.mxnet._ import org.apache.mxnet.module.Module import scala.collection.mutable.ListBuffer @@ -36,11 +37,13 @@ private[infer] trait PredictBase { *

* This method will take input as IndexedSeq one dimensional arrays and creates the * NDArray needed for inference. The array will be reshaped based on the input descriptors. - * @param input: An IndexedSequence of a one-dimensional array. + * @param input: An Indexed Sequence of a one-dimensional array of datatype + * Float or Double An IndexedSequence is needed when the model has more than one input. * @return Indexed sequence array of outputs */ - def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]] + def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]]) + : IndexedSeq[Array[T]] /** * Predict using NDArray as input. @@ -123,13 +126,13 @@ class Predictor(modelPathPrefix: String, * Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference * The array will be reshaped based on the input descriptors. * - * @param input: An IndexedSequence of a one-dimensional array. + * @param input: An IndexedSequence of a one-dimensional array + * of data type Float or Double. An IndexedSequence is needed when the model has more than one input. * @return Indexed sequence array of outputs */ - override def predict(input: IndexedSeq[Array[Float]]) - : IndexedSeq[Array[Float]] = { - + override def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]]) + : IndexedSeq[Array[T]] = { require(input.length == inputDescriptors.length, s"number of inputs provided: ${input.length} does not match number of inputs " + s"in inputDescriptors: ${inputDescriptors.length}") @@ -139,12 +142,30 @@ class Predictor(modelPathPrefix: String, s"number of elements:${i.length} in the input does not match the shape:" + s"${d.shape.toString()}") } + + // Infer the dtype of input and call relevant method + val result = input(0)(0) match { + case d: Double => predictImpl(input.asInstanceOf[IndexedSeq[Array[Double]]]) + case _ => predictImpl(input.asInstanceOf[IndexedSeq[Array[Float]]]) + } + + result.asInstanceOf[IndexedSeq[Array[T]]] + } + + private def predictImpl[B, A <: MX_PRIMITIVE_TYPE] + (input: IndexedSeq[Array[B]])(implicit ev: B => A) + : IndexedSeq[Array[B]] = { + var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray] for((i, d) <- input.zip(inputDescriptors)) { val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1) - - inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape))) + if (d.dtype == DType.Float64) { + inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Double]], Shape(shape))) + } + else { + inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Float]], Shape(shape))) + } } // rebind with batchsize 1 @@ -158,7 +179,8 @@ class Predictor(modelPathPrefix: String, val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter( inputND.toIndexedSeq, dataBatchSize = 1))) - val result = resultND.map((f : NDArray) => f.toArray) + val result = + resultND.map((f : NDArray) => if (f.dtype == DType.Float64) f.toFloat64Array else f.toArray) mxNetHandler.execute(inputND.foreach(_.dispose)) mxNetHandler.execute(resultND.foreach(_.dispose)) @@ -168,9 +190,11 @@ class Predictor(modelPathPrefix: String, mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true)) } - result + result.asInstanceOf[IndexedSeq[Array[B]]] } + + /** * Predict using NDArray as input * This method is useful when the input is a batch of data diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala index 0466693be9bc..146fe93105e4 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala @@ -72,6 +72,30 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor) predictor.predict(input).toArray } + /** + * Takes input as Array of one dimensional arrays and creates the NDArray needed for inference + * The array will be reshaped based on the input descriptors. Example of calling in Java: + * + *

+    * {@code
+    * double tmp[][] = new double[1][224];
+    * for (int x = 0; x < 1; x++)
+    *   for (int y = 0; y < 224; y++)
+    *     tmp[x][y] = (int)(Math.random()*10);
+    * predictor.predict(tmp);
+    * }
+    * 
+ * + * @param input: An Array of a one-dimensional array. + An extra Array is needed for when the model has more than one input. + * @return Indexed sequence array of outputs + */ + + def predict(input: Array[Array[Double]]): + Array[Array[Double]] = { + predictor.predict(input).toArray + } + /** * Takes input as List of one dimensional arrays and creates the NDArray needed for inference * The array will be reshaped based on the input descriptors. diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala index b28aeba1deed..d9ccec468791 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala @@ -22,7 +22,7 @@ import java.nio.file.{Files, Paths} import java.util import org.apache.mxnet.module.Module -import org.apache.mxnet.{Context, DataDesc, NDArray, Shape} +import org.apache.mxnet.{Context, DType, DataDesc, NDArray, Shape} import org.mockito.Matchers._ import org.mockito.Mockito import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -127,6 +127,29 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll { } + test("ClassifierSuite-flatFloat64Array-topK") { + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2))) + val inputData = Array.fill[Double](12)(1d) + + val predictResult : IndexedSeq[Array[Double]] = + IndexedSeq[Array[Double]](Array(.98d, 0.97d, 0.96d, 0.99d)) + + val testClassifier = new MyClassifier(modelPath, inputDescriptor) + + Mockito.doReturn(predictResult).when(testClassifier.predictor) + .predict(any(classOf[IndexedSeq[Array[Double]]])) + + val result: IndexedSeq[(String, Double)] = testClassifier. + classify(IndexedSeq(inputData), topK = Some(10)) + + assert((result(0)_2).getClass == 1d.getClass) + + assertResult(predictResult(0).sortBy(-_)) { + result.map(_._2).toArray + } + + } + test("ClassifierSuite-flatArrayInput") { val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2))) val inputData = Array.fill[Float](12)(1) @@ -147,6 +170,28 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll { } } + test("ClassifierSuite-flatArrayFloat64Input") { + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2))) + val inputData = Array.fill[Double](12)(1d) + + val predictResult : IndexedSeq[Array[Double]] = + IndexedSeq[Array[Double]](Array(.98d, 0.97d, 0.96d, 0.99d)) + + val testClassifier = new MyClassifier(modelPath, inputDescriptor) + + Mockito.doReturn(predictResult).when(testClassifier.predictor) + .predict(any(classOf[IndexedSeq[Array[Double]]])) + + val result: IndexedSeq[(String, Double)] = testClassifier. + classify(IndexedSeq(inputData)) + + assert((result(0)_2).getClass == 1d.getClass) + + assertResult(predictResult(0)) { + result.map(_._2).toArray + } + } + test("ClassifierSuite-NDArray1InputWithoutTopK") { val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2))) val inputDataShape = Shape(1, 3, 2, 2) diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala index 1c291e1e7b3c..5198c4a1f309 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala @@ -68,6 +68,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { val result = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2)) assert(result.shape == inputDescriptor(0).shape.drop(1)) + assert(result.dtype == DType.Float32) + + val resultFloat64 = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2), DType.Float64) + assert(resultFloat64.dtype == DType.Float64) } test("ImageClassifierSuite-testWithInputImage") { @@ -106,8 +110,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { predictResult(i).map(_._2).toArray } } + } + test("ImageClassifierSuite-testWithInputBatchImage") { val dType = DType.Float32 val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512), @@ -152,4 +158,5 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll { } } } + } diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala index 509ffb35db8d..9afbc9b3d4a8 100644 --- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala +++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala @@ -19,7 +19,7 @@ package org.apache.mxnet.infer import org.apache.mxnet.io.NDArrayIter import org.apache.mxnet.module.{BaseModule, Module} -import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape} +import org.apache.mxnet._ import org.mockito.Matchers._ import org.mockito.Mockito import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -91,6 +91,36 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll { , any[Option[BaseModule]], any[String]) } + test("PredictorSuite-testWithFlatFloat64Arrays") { + + val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2), + layout = Layout.NCHW, dtype = DType.Float64)) + val inputData = Array.fill[Double](12)(1d) + + // this will disposed at the end of the predict call on Predictor. + val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2), dtype = DType.Float64)) + + val testPredictor = new MyPredictor("xyz", inputDescriptor) + + Mockito.doReturn(predictResult).when(testPredictor.mockModule) + .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean]) + + val testFun = testPredictor.predict(IndexedSeq(inputData)) + + assert(testFun.size == 1, "output size should be 1 ") + + assert(testFun(0)(0).getClass == 1d.getClass) + + assert(Array.fill[Double](12)(1d).mkString == testFun(0).mkString) + + // Verify that the module was bound with batch size 1 and rebound back to the original + // input descriptor. the number of times is twice here because loadModule overrides the + // initial bind. + Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]], + any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean] + , any[Option[BaseModule]], any[String]) + } + test("PredictorSuite-testWithNDArray") { val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2), layout = Layout.NCHW)) diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index d684c6d13564..ea6e9c8f5ba4 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -424,6 +424,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU return ret; } +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU + (JNIEnv *env, jobject obj, jlong arrayPtr, jdoubleArray sourceArr, jint arrSize) { + jdouble *sourcePtr = env->GetDoubleArrayElements(sourceArr, NULL); + int ret = MXNDArraySyncCopyFromCPU(reinterpret_cast(arrayPtr), + static_cast(sourcePtr), arrSize); + env->ReleaseDoubleArrayElements(sourceArr, sourcePtr, 0); + return ret; +} + JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext (JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) { int outDevType; diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h index 40230ac6daae..7e8e03de9124 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h @@ -175,6 +175,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU (JNIEnv *, jobject, jlong, jfloatArray, jint); +/* + * Class: org_apache_mxnet_LibInfo + * Method: mxFloat64NDArraySyncCopyFromCPU + * Signature: (J[DI)I + */ +JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU + (JNIEnv *, jobject, jlong, jdoubleArray, jint); + /* * Class: org_apache_mxnet_LibInfo * Method: mxNDArrayLoad