diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index b9f84d592a70..5b5fdce712f1 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -193,6 +193,7 @@ List of Contributors
* [Yuxi Hu](https://github.com/yuxihu)
* [Harsh Patel](https://github.com/harshp8l)
* [Xiao Wang](https://github.com/BeyonderXX)
+* [Piyush Ghai](https://github.com/piyushghai)
Label Bot
---------
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index b2b23da6274e..224a39275dac 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -18,6 +18,7 @@
(ns org.apache.clojure-mxnet.infer
(:refer-clojure :exclude [type])
(:require [org.apache.clojure-mxnet.context :as context]
+ [org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.shape :as shape]
[org.apache.clojure-mxnet.util :as util]
@@ -62,10 +63,12 @@
(defprotocol AImageClassifier
(classify-image
[wrapped-image-classifier image]
- [wrapped-image-classifier image topk])
+ [wrapped-image-classifier image topk]
+ [wrapped-image-classifier image topk dtype])
(classify-image-batch
[wrapped-image-classifier images]
- [wrapped-image-classifier images topk]))
+ [wrapped-image-classifier images topk]
+ [wrapped-image-classifier images topk dtype]))
(defprotocol AObjectDetector
(detect-objects
@@ -80,7 +83,8 @@
(extend-protocol APredictor
WrappedPredictor
- (predict [wrapped-predictor inputs]
+ (predict
+ [wrapped-predictor inputs]
(util/validate! ::wrapped-predictor wrapped-predictor
"Invalid predictor")
(util/validate! ::vec-of-float-arrays inputs
@@ -101,62 +105,50 @@
(extend-protocol AClassifier
WrappedClassifier
- (classify [wrapped-classifier inputs]
- (util/validate! ::wrapped-classifier wrapped-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
- "Invalid inputs")
- (classify wrapped-classifier inputs nil))
- (classify [wrapped-classifier inputs topk]
- (util/validate! ::wrapped-classifier wrapped-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
- "Invalid inputs")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classify (:classifier wrapped-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk))))
- (classify-with-ndarray [wrapped-classifier inputs]
- (util/validate! ::wrapped-classifier wrapped-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-ndarrays inputs
- "Invalid inputs")
- (classify-with-ndarray wrapped-classifier inputs nil))
- (classify-with-ndarray [wrapped-classifier inputs topk]
- (util/validate! ::wrapped-classifier wrapped-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-ndarrays inputs
- "Invalid inputs")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classifyWithNDArray (:classifier wrapped-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk))))
+ (classify
+ ([wrapped-classifier inputs]
+ (classify wrapped-classifier inputs nil))
+ ([wrapped-classifier inputs topk]
+ (util/validate! ::wrapped-classifier wrapped-classifier
+ "Invalid classifier")
+ (util/validate! ::vec-of-float-arrays inputs
+ "Invalid inputs")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.classify (:classifier wrapped-classifier)
+ (util/vec->indexed-seq inputs)
+ (util/->int-option topk)))))
+ (classify-with-ndarray
+ ([wrapped-classifier inputs]
+ (classify-with-ndarray wrapped-classifier inputs nil))
+ ([wrapped-classifier inputs topk]
+ (util/validate! ::wrapped-classifier wrapped-classifier
+ "Invalid classifier")
+ (util/validate! ::vec-of-ndarrays inputs
+ "Invalid inputs")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.classifyWithNDArray (:classifier wrapped-classifier)
+ (util/vec->indexed-seq inputs)
+ (util/->int-option topk)))))
WrappedImageClassifier
- (classify [wrapped-image-classifier inputs]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
- "Invalid inputs")
- (classify wrapped-image-classifier inputs nil))
- (classify [wrapped-image-classifier inputs topk]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-float-arrays inputs
- "Invalid inputs")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classify (:image-classifier wrapped-image-classifier)
- (util/vec->indexed-seq inputs)
- (util/->int-option topk))))
- (classify-with-ndarray [wrapped-image-classifier inputs]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::vec-of-ndarrays inputs
- "Invalid inputs")
- (classify-with-ndarray wrapped-image-classifier inputs nil))
- (classify-with-ndarray [wrapped-image-classifier inputs topk]
+ (classify
+ ([wrapped-image-classifier inputs]
+ (classify wrapped-image-classifier inputs nil))
+ ([wrapped-image-classifier inputs topk]
+ (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+ "Invalid classifier")
+ (util/validate! ::vec-of-float-arrays inputs
+ "Invalid inputs")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.classify (:image-classifier wrapped-image-classifier)
+ (util/vec->indexed-seq inputs)
+ (util/->int-option topk)))))
+ (classify-with-ndarray
+ ([wrapped-image-classifier inputs]
+ (classify-with-ndarray wrapped-image-classifier inputs nil))
+ ([wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
@@ -165,83 +157,83 @@
(util/coerce-return-recursive
(.classifyWithNDArray (:image-classifier wrapped-image-classifier)
(util/vec->indexed-seq inputs)
- (util/->int-option topk)))))
+ (util/->int-option topk))))))
(s/def ::image #(instance? BufferedImage %))
+(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
(extend-protocol AImageClassifier
WrappedImageClassifier
- (classify-image [wrapped-image-classifier image]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::image image "Invalid image")
- (classify-image wrapped-image-classifier image nil))
- (classify-image [wrapped-image-classifier image topk]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::image image "Invalid image")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classifyImage (:image-classifier wrapped-image-classifier)
- image
- (util/->int-option topk))))
- (classify-image-batch [wrapped-image-classifier images]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (classify-image-batch wrapped-image-classifier images nil))
- (classify-image-batch [wrapped-image-classifier images topk]
- (util/validate! ::wrapped-image-classifier wrapped-image-classifier
- "Invalid classifier")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.classifyImageBatch (:image-classifier wrapped-image-classifier)
- images
- (util/->int-option topk)))))
+ (classify-image
+ ([wrapped-image-classifier image]
+ (classify-image wrapped-image-classifier image nil dtype/FLOAT32))
+ ([wrapped-image-classifier image topk]
+ (classify-image wrapped-image-classifier image topk dtype/FLOAT32))
+ ([wrapped-image-classifier image topk dtype]
+ (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+ "Invalid classifier")
+ (util/validate! ::image image "Invalid image")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/validate! ::dtype dtype "Invalid dtype")
+ (util/coerce-return-recursive
+ (.classifyImage (:image-classifier wrapped-image-classifier)
+ image
+ (util/->int-option topk)
+ dtype))))
+ (classify-image-batch
+ ([wrapped-image-classifier images]
+ (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
+ ([wrapped-image-classifier images topk]
+ (classify-image-batch wrapped-image-classifier images topk dtype/FLOAT32))
+ ([wrapped-image-classifier images topk dtype]
+ (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+ "Invalid classifier")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/validate! ::dtype dtype "Invalid dtype")
+ (util/coerce-return-recursive
+ (.classifyImageBatch (:image-classifier wrapped-image-classifier)
+ images
+ (util/->int-option topk)
+ dtype)))))
(extend-protocol AObjectDetector
WrappedObjectDetector
- (detect-objects [wrapped-detector image]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (util/validate! ::image image "Invalid image")
- (detect-objects wrapped-detector image nil))
- (detect-objects [wrapped-detector image topk]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (util/validate! ::image image "Invalid image")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.imageObjectDetect (:object-detector wrapped-detector)
- image
- (util/->int-option topk))))
- (detect-objects-batch [wrapped-detector images]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (detect-objects-batch wrapped-detector images nil))
- (detect-objects-batch [wrapped-detector images topk]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.imageBatchObjectDetect (:object-detector wrapped-detector)
- images
- (util/->int-option topk))))
- (detect-objects-with-ndarrays [wrapped-detector input-arrays]
- (util/validate! ::wrapped-detector wrapped-detector
- "Invalid object detector")
- (util/validate! ::vec-of-ndarrays input-arrays
- "Invalid inputs")
- (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
- (detect-objects-with-ndarrays [wrapped-detector input-arrays topk]
+ (detect-objects
+ ([wrapped-detector image]
+ (detect-objects wrapped-detector image nil))
+ ([wrapped-detector image topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
- (util/validate! ::vec-of-ndarrays input-arrays
- "Invalid inputs")
- (util/validate! ::nil-or-int topk "Invalid top-K")
- (util/coerce-return-recursive
- (.objectDetectWithNDArray (:object-detector wrapped-detector)
- (util/vec->indexed-seq input-arrays)
+ (util/validate! ::image image "Invalid image")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.imageObjectDetect (:object-detector wrapped-detector)
+ image
+ (util/->int-option topk)))))
+ (detect-objects-batch
+ ([wrapped-detector images]
+ (detect-objects-batch wrapped-detector images nil))
+ ([wrapped-detector images topk]
+ (util/validate! ::wrapped-detector wrapped-detector
+ "Invalid object detector")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.imageBatchObjectDetect (:object-detector wrapped-detector)
+ images
(util/->int-option topk)))))
+ (detect-objects-with-ndarrays
+ ([wrapped-detector input-arrays]
+ (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
+ ([wrapped-detector input-arrays topk]
+ (util/validate! ::wrapped-detector wrapped-detector
+ "Invalid object detector")
+ (util/validate! ::vec-of-ndarrays input-arrays
+ "Invalid inputs")
+ (util/validate! ::nil-or-int topk "Invalid top-K")
+ (util/coerce-return-recursive
+ (.objectDetectWithNDArray (:object-detector wrapped-detector)
+ (util/vec->indexed-seq input-arrays)
+ (util/->int-option topk))))))
(defprotocol AInferenceFactory
(create-predictor [factory] [factory opts])
@@ -335,7 +327,7 @@
[image input-shape-vec]
(util/validate! ::image image "Invalid image")
(util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
- (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec)))
+ (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32))
(s/def ::image-path string?)
(s/def ::image-paths (s/coll-of ::image-path))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
new file mode 100644
index 000000000000..0967df2289d8
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
@@ -0,0 +1,46 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.primitives
+ (:import (org.apache.mxnet MX_PRIMITIVES$MX_FLOAT MX_PRIMITIVES$MX_Double
+ MX_PRIMITIVES$MX_PRIMITIVE_TYPE)))
+
+
+;;; Defines custom MXNet primitives that can be used for mathematical computations
+;;; in NDArrays to control precision. Currently Float and Double are supported.
+
+;;; For purposes of automatic conversion in NDArray functions, doubles are the default;
+;;; to specify single precision you must pass a Float explicitly.
+
+(defn mx-float
+ "Creates a MXNet float primitive"
+ [num]
+ (new MX_PRIMITIVES$MX_FLOAT num))
+
+(defn mx-double
+ "Creates a MXNet double primitive"
+ [num]
+ (new MX_PRIMITIVES$MX_Double num))
+
+(defn ->num
+ "Returns the underlying number value"
+ [primitive]
+ (.data primitive))
+
+(defn primitive? [x]
+ (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE x))
+
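The Clojure helpers above are thin wrappers over the Scala-side primitive classes introduced later in this patch (MX_PRIMITIVES.scala). A minimal Scala sketch of the same round-trip, using only the classes defined there; the object name `PrimitivesRoundTrip` is illustrative:

```scala
import org.apache.mxnet.MX_PRIMITIVES.{MX_FLOAT, MX_Double, MX_PRIMITIVE_TYPE}

object PrimitivesRoundTrip {
  def main(args: Array[String]): Unit = {
    // (primitives/mx-float 3) and (primitives/mx-double 3) construct these wrappers
    val f = new MX_FLOAT(3f)
    val d = new MX_Double(3.0)

    // (primitives/primitive? x) is an instance check against the shared trait
    assert(f.isInstanceOf[MX_PRIMITIVE_TYPE] && d.isInstanceOf[MX_PRIMITIVE_TYPE])

    // (primitives/->num p) reads the wrapped number back via the public `data` field
    println(f.data) // 3.0
    println(d.data) // 3.0
  }
}
```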
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 21e31baa3a9b..43970c0abd79 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -19,6 +19,7 @@
(:require [clojure.spec.alpha :as s]
[t6.from-scala.core :refer [$ $$] :as $]
[clojure.string :as string]
+ [org.apache.clojure-mxnet.primitives :as primitives]
[org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet NDArray)
(scala Product Tuple2 Tuple3)
@@ -36,7 +37,8 @@
"byte<>" "byte-array"
"java.lang.String<>" "vec-or-strings"
"org.apache.mxnet.NDArray" "ndarray"
- "org.apache.mxnet.Symbol" "sym"})
+ "org.apache.mxnet.Symbol" "sym"
+ "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"})
(def symbol-param-coerce {"java.lang.String" "sym-name"
"float" "num"
@@ -144,6 +146,8 @@
(and (get targets "int<>") (vector? param)) (int-array param)
(and (get targets "float<>") (vector? param)) (float-array param)
(and (get targets "java.lang.String<>") (vector? param)) (into-array param)
+ (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param)
+ (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param)
:else param))
(defn nil-or-coerce-param [param targets]
@@ -177,6 +181,7 @@
(instance? Map return-val) (scala-map->map return-val)
(instance? Tuple2 return-val) (tuple->vec return-val)
(instance? Tuple3 return-val) (tuple->vec return-val)
+ (primitives/primitive? return-val) (primitives/->num return-val)
:else return-val))
(defn coerce-return-recursive [return-val]
diff --git a/contrib/clojure-package/test/good-test-ndarray.clj b/contrib/clojure-package/test/good-test-ndarray.clj
index 3b53b1906006..b048a819c642 100644
--- a/contrib/clojure-package/test/good-test-ndarray.clj
+++ b/contrib/clojure-package/test/good-test-ndarray.clj
@@ -27,11 +27,12 @@
(defn
div
- ([ndarray num-or-ndarray]
+ ([ndarray ndarray-or-double-or-float]
(util/coerce-return
(.$div
ndarray
(util/coerce-param
- num-or-ndarray
- #{"float" "org.apache.mxnet.NDArray"})))))
+ ndarray-or-double-or-float
+ #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"
+ "org.apache.mxnet.NDArray"})))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index 9badfed933a5..b459b06132b2 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -40,7 +40,11 @@
(deftest test-single-classification
(let [classifier (create-classifier)
image (infer/load-image-from-file "test/test-images/kitten.jpg")
- [predictions] (infer/classify-image classifier image 5)]
+ [predictions-all] (infer/classify-image classifier image)
+ [predictions-with-default-dtype] (infer/classify-image classifier image 10)
+ [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)]
+ (is (= 1000 (count predictions-all)))
+ (is (= 10 (count predictions-with-default-dtype)))
(is (some? predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
@@ -58,8 +62,12 @@
(let [classifier (create-classifier)
image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
"test/test-images/Pug-Cookie.jpg"])
- batch-predictions (infer/classify-image-batch classifier image-batch 5)
+ batch-predictions-all (infer/classify-image-batch classifier image-batch)
+ batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10)
+ batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)
predictions (first batch-predictions)]
+ (is (= 1000 (count (first batch-predictions-all))))
+ (is (= 10 (count (first batch-predictions-with-default-dtype))))
(is (some? batch-predictions))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 788a59491095..3a0e3d30a1d9 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -40,9 +40,11 @@
(deftest test-single-detection
(let [detector (create-detector)
image (infer/load-image-from-file "test/test-images/kitten.jpg")
+ [predictions-all] (infer/detect-objects detector image)
[predictions] (infer/detect-objects detector image 5)]
(is (some? predictions))
(is (= 5 (count predictions)))
+ (is (= 13 (count predictions-all)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
(is (every? #(= 5 (count (second %))) predictions))
@@ -53,9 +55,11 @@
(let [detector (create-detector)
image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
"test/test-images/Pug-Cookie.jpg"])
+ batch-predictions-all (infer/detect-objects-batch detector image-batch)
batch-predictions (infer/detect-objects-batch detector image-batch 5)
predictions (first batch-predictions)]
(is (some? batch-predictions))
+ (is (= 13 (count (first batch-predictions-all))))
(is (= 5 (count predictions)))
(is (every? #(= 2 (count %)) predictions))
(is (every? #(string? (first %)) predictions))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
index 79e94412d0df..9ffd3abed2f9 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
@@ -97,7 +97,7 @@
(is (= [1.0 1.0] (->vec ndhalves)))))
(deftest test-full
- (let [nda (full [1 2] 3)]
+ (let [nda (full [1 2] 3.0)]
(is (= (shape nda) (mx-shape/->shape [1 2])))
(is (= [3.0 3.0] (->vec nda)))))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj
new file mode 100644
index 000000000000..1a538e537b8b
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj
@@ -0,0 +1,45 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;; http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.primitives-test
+ (:require [org.apache.clojure-mxnet.primitives :as primitives]
+ [clojure.test :refer :all])
+ (:import (org.apache.mxnet MX_PRIMITIVES$MX_PRIMITIVE_TYPE
+ MX_PRIMITIVES$MX_FLOAT
+ MX_PRIMITIVES$MX_Double)))
+
+(deftest test-primitive-types
+ (is (not (primitives/primitive? 3)))
+ (is (primitives/primitive? (primitives/mx-float 3)))
+ (is (primitives/primitive? (primitives/mx-double 3))))
+
+(deftest test-float-primitives
+ (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-float 3)))
+ (is (instance? MX_PRIMITIVES$MX_FLOAT (primitives/mx-float 3)))
+ (is (instance? Float (-> (primitives/mx-float 3)
+ (primitives/->num))))
+ (is (= 3.0 (-> (primitives/mx-float 3)
+ (primitives/->num)))))
+
+(deftest test-double-primitives
+ (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-double 2)))
+ (is (instance? MX_PRIMITIVES$MX_Double (primitives/mx-double 2)))
+ (is (instance? Double (-> (primitives/mx-double 2)
+ (primitives/->num))))
+ (is (= 2.0 (-> (primitives/mx-double 2)
+ (primitives/->num)))))
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index bd77a8a0edc6..c26f83d5aa49 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -20,6 +20,7 @@
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.ndarray :as ndarray]
+ [org.apache.clojure-mxnet.primitives :as primitives]
[org.apache.clojure-mxnet.symbol :as sym]
[org.apache.clojure-mxnet.test-util :as test-util]
[clojure.spec.alpha :as s])
@@ -133,6 +134,9 @@
(is (= "[F" (->> (util/coerce-param [1 2] #{"float<>"}) str (take 2) (apply str))))
(is (= "[L" (->> (util/coerce-param [1 2] #{"java.lang.String<>"}) str (take 2) (apply str))))
+ (is (primitives/primitive? (util/coerce-param 1.0 #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))
+ (is (primitives/primitive? (util/coerce-param (float 1.0) #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))
+
(is (= 1 (util/coerce-param 1 #{"unknown"}))))
(deftest test-nil-or-coerce-param
@@ -171,6 +175,12 @@
(util/convert-tuple [1 2]))))
(is (= [1 2 3] (util/coerce-return
(util/convert-tuple [1 2 3]))))
+
+ (is (instance? Double (util/coerce-return (primitives/mx-double 3))))
+ (is (= 3.0 (util/coerce-return (primitives/mx-double 3))))
+ (is (instance? Float (util/coerce-return (primitives/mx-float 2))))
+ (is (= 2.0 (util/coerce-return (primitives/mx-float 2))))
+
(is (= "foo" (util/coerce-return "foo"))))
(deftest test-translate-keyword-shape
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
index ed7aff602f63..001bd04d2c95 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
@@ -18,7 +18,9 @@
package org.apache.mxnet
import org.apache.mxnet.util.NativeLibraryLoader
-import org.slf4j.{LoggerFactory, Logger}
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.Specializable.Group
private[mxnet] object Base {
private val logger: Logger = LoggerFactory.getLogger("MXNetJVM")
@@ -57,6 +59,9 @@ private[mxnet] object Base {
val MX_REAL_TYPE = DType.Float32
+ // The primitives currently supported for NDArray operations
+ val MX_PRIMITIVES = new Group((Double, Float))
+
try {
try {
tryLoadLibraryOS("mxnet-scala")
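Base.MX_PRIMITIVES is a scala.Specializable.Group over (Double, Float); annotating a type parameter with @specialized(Base.MX_PRIMITIVES) asks the compiler to emit Float and Double specializations, which is how Predictor.predict is declared later in this patch. A minimal sketch, assuming the code lives inside the org.apache.mxnet package (Base is private[mxnet]); the object and method names are illustrative:

```scala
package org.apache.mxnet

// A generic helper specialized for the (Double, Float) group declared in Base.
// Base is private[mxnet], so this sketch must live inside the org.apache.mxnet package.
private[mxnet] object SpecializationSketch {
  def sum[@specialized(Base.MX_PRIMITIVES) T](xs: Array[T])(implicit num: Numeric[T]): T =
    xs.foldLeft(num.zero)(num.plus)

  def demo(): Unit = {
    println(sum(Array(1.0f, 2.0f, 3.0f))) // uses the Float specialization
    println(sum(Array(1.0, 2.0, 3.0)))    // uses the Double specialization
  }
}
```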
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 0a5683aa7ab3..20b6ed9fc806 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -93,6 +93,9 @@ private[mxnet] class LibInfo {
@native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
source: Array[MXFloat],
size: Int): Int
+ @native def mxFloat64NDArraySyncCopyFromCPU(handle: NDArrayHandle,
+ source: Array[Double],
+ size: Int): Int
@native def mxNDArrayLoad(fname: String,
outSize: MXUintRef,
handles: ArrayBuffer[NDArrayHandle],
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
new file mode 100644
index 000000000000..cb978856963c
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+object MX_PRIMITIVES {
+
+ /**
+ * This defines the basic primitives we can use in Scala for mathematical
+ * computations in NDArrays. This gives us the flexibility to expand to
+ * more supported primitives in the future. Currently Float and Double
+ * are supported. Functions which accept MX_PRIMITIVE_TYPE as input can also
+ * accept plain Float and Double values because of the underlying
+ * implicit conversions from those primitives to MX_PRIMITIVE_TYPE.
+ */
+ trait MX_PRIMITIVE_TYPE extends Ordered[MX_PRIMITIVE_TYPE] {
+
+ def toString: String
+
+ def unary_- : MX_PRIMITIVE_TYPE
+ }
+
+ trait MXPrimitiveOrdering extends Ordering[MX_PRIMITIVE_TYPE] {
+
+ def compare(x: MX_PRIMITIVE_TYPE, y: MX_PRIMITIVE_TYPE): Int = x.compare(y)
+
+ }
+
+ implicit object MX_PRIMITIVE_TYPE extends MXPrimitiveOrdering
+
+ /**
+ * Wrapper over Float in Scala.
+ * @param data the underlying Float value being wrapped
+ */
+ class MX_FLOAT(val data: Float) extends MX_PRIMITIVE_TYPE {
+
+ override def toString: String = data.toString
+
+ override def unary_- : MX_PRIMITIVE_TYPE = new MX_FLOAT(data.unary_-)
+
+ override def compare(that: MX_PRIMITIVE_TYPE): Int = {
+ this.data.compareTo(that.asInstanceOf[MX_FLOAT].data)
+ }
+ }
+
+ implicit def FloatToMX_Float(d : Float): MX_FLOAT = new MX_FLOAT(d)
+
+ implicit def MX_FloatToFloat(d: MX_FLOAT) : Float = d.data
+
+ implicit def IntToMX_Float(d: Int): MX_FLOAT = new MX_FLOAT(d.toFloat)
+
+ /**
+ * Wrapper over Double in Scala.
+ * @param data the underlying Double value being wrapped
+ */
+ class MX_Double(val data: Double) extends MX_PRIMITIVE_TYPE {
+
+ override def toString: String = data.toString
+
+ override def unary_- : MX_PRIMITIVE_TYPE = new MX_Double(data.unary_-)
+
+ override def compare(that: MX_PRIMITIVE_TYPE): Int = {
+ this.data.compareTo(that.asInstanceOf[MX_Double].data)
+ }
+ }
+
+ implicit def DoubleToMX_Double(d : Double): MX_Double = new MX_Double(d)
+
+ implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data
+
+}
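Because the wrapper classes ship with implicit conversions in both directions, APIs that switch from Float to MX_PRIMITIVE_TYPE keep accepting plain Float and Double literals. A small sketch using only the members defined above; the object and method names are illustrative:

```scala
import org.apache.mxnet.MX_PRIMITIVES._

object ImplicitConversionSketch {
  // A function written against the common trait...
  def describe(x: MX_PRIMITIVE_TYPE): String = x.toString

  def main(args: Array[String]): Unit = {
    // ...still accepts plain literals via FloatToMX_Float / DoubleToMX_Double.
    println(describe(1.5f))
    println(describe(2.5))

    // The reverse conversions unwrap back to the plain primitives.
    val f: Float = new MX_FLOAT(3f)
    val d: Double = new MX_Double(4.0)
    println(f + d) // 7.0
  }
}
```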
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 125958150b72..163ed2682532 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
+import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE}
import org.slf4j.LoggerFactory
import scala.collection.mutable
@@ -262,16 +263,46 @@ object NDArray extends NDArrayBase {
arr
}
- // Perform power operator
+ def full(shape: Shape, value: Double, ctx: Context): NDArray = {
+ val arr = empty(shape, ctx, DType.Float64)
+ arr.set(value)
+ arr
+ }
+
+ /**
+ * Create a new Float64 (Double) NDArray filled with the given value, with the specified shape.
+ * @param shape shape of the NDArray.
+ * @param value the value to fill the NDArray with
+ */
+ def full(shape: Shape, value: Double): NDArray = {
+ full(shape, value, null)
+ }
+
+
+ /**
+ * Perform element-wise power operation on two NDArrays. Returns the result as NDArray.
+ * @param lhs the base NDArray
+ * @param rhs the exponent NDArray
+ */
def power(lhs: NDArray, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_power", Seq(lhs, rhs))
}
- def power(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Perform scalar power operation on NDArray. Returns result as NDArray
+ * @param lhs NDArray on which to perform the operation.
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def power(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(lhs, rhs))
}
- def power(lhs: Float, rhs: NDArray): NDArray = {
+ /**
+ * Perform scalar power operation on NDArray. Returns result as NDArray
+ * @param lhs The scalar input. Can be of type Float/Double
+ * @param rhs NDArray on which to perform the operation.
+ */
+ def power(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_rpower_scalar", Seq(lhs, rhs))
}
@@ -280,11 +311,21 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("_maximum", Seq(lhs, rhs))
}
- def maximum(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Perform the max operation on NDArray. Returns the result as NDArray.
+ * @param lhs NDArray on which to perform the operation.
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def maximum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
}
- def maximum(lhs: Float, rhs: NDArray): NDArray = {
+ /**
+ * Perform the max operation on NDArray. Returns the result as NDArray.
+ * @param lhs The scalar input. Can be of type Float/Double
+ * @param rhs NDArray on which to perform the operation.
+ */
+ def maximum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
}
@@ -293,11 +334,21 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("_minimum", Seq(lhs, rhs))
}
- def minimum(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Perform the min operation on NDArray. Returns the result as NDArray.
+ * @param lhs NDArray on which to perform the operation.
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def minimum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
}
- def minimum(lhs: Float, rhs: NDArray): NDArray = {
+ /**
+ * Perform the min operation on NDArray. Returns the result as NDArray.
+ * @param lhs The scalar input. Can be of type Float/Double
+ * @param rhs NDArray on which to perform the operation.
+ */
+ def minimum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
}
@@ -310,7 +361,15 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_equal", Seq(lhs, rhs))
}
- def equal(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
+ * For each element in input arrays, return 1(true) if corresponding elements are the same,
+ * otherwise return 0(false).
+ *
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def equal(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_equal_scalar", Seq(lhs, rhs))
}
@@ -324,7 +383,15 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_not_equal", Seq(lhs, rhs))
}
- def notEqual(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **not equal to** (!=) comparison operation
+ * with broadcasting.
+ * For each element in input arrays, return 1(true) if corresponding elements are different,
+ * otherwise return 0(false).
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def notEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_not_equal_scalar", Seq(lhs, rhs))
}
@@ -338,7 +405,16 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_greater", Seq(lhs, rhs))
}
- def greater(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **greater than** (>) comparison operation
+ * with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are greater than rhs,
+ * otherwise return 0(false).
+ *
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def greater(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_greater_scalar", Seq(lhs, rhs))
}
@@ -352,7 +428,16 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_greater_equal", Seq(lhs, rhs))
}
- def greaterEqual(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **greater than or equal to** (>=) comparison
+ * operation with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are greater than or equal to
+ * rhs, otherwise return 0(false).
+ *
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def greaterEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_greater_equal_scalar", Seq(lhs, rhs))
}
@@ -366,7 +451,15 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_lesser", Seq(lhs, rhs))
}
- def lesser(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **lesser than** (<) comparison operation
+ * with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are less than rhs,
+ * otherwise return 0(false).
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def lesser(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_lesser_scalar", Seq(lhs, rhs))
}
@@ -380,7 +473,16 @@ object NDArray extends NDArrayBase {
NDArray.genericNDArrayFunctionInvoke("broadcast_lesser_equal", Seq(lhs, rhs))
}
- def lesserEqual(lhs: NDArray, rhs: Float): NDArray = {
+ /**
+ * Returns the result of element-wise **lesser than or equal to** (<=) comparison
+ * operation with broadcasting.
+ * For each element in input arrays, return 1(true) if lhs elements are
+ * lesser than or equal to rhs, otherwise return 0(false).
+ *
+ * @param lhs NDArray
+ * @param rhs The scalar input. Can be of type Float/Double
+ */
+ def lesserEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_lesser_equal_scalar", Seq(lhs, rhs))
}
@@ -397,6 +499,16 @@ object NDArray extends NDArrayBase {
arr
}
+ def array(sourceArr: Array[Double], shape: Shape, ctx: Context): NDArray = {
+ val arr = empty(shape, ctx, dtype = DType.Float64)
+ arr.set(sourceArr)
+ arr
+ }
+
+ def array(sourceArr: Array[Double], shape: Shape): NDArray = {
+ array(sourceArr, shape, null)
+ }
+
/**
* Returns evenly spaced values within a given interval.
* Values are generated within the half-open interval [`start`, `stop`). In other
@@ -645,6 +757,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length))
}
+ private def syncCopyfrom(source: Array[Double]): Unit = {
+ require(source.length == size,
+ s"array size (${source.length}) does not match the size of the NDArray ($size)")
+ checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length))
+ }
+
/**
* Return a sliced NDArray that shares memory with current one.
* NDArray only support continuous slicing on axis 0
@@ -759,7 +877,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* @param value Value to set
* @return Current NDArray
*/
- def set(value: Float): NDArray = {
+ def set(value: MX_PRIMITIVE_TYPE): NDArray = {
require(writable, "trying to assign to a readonly NDArray")
NDArray.genericNDArrayFunctionInvoke("_set_value", Seq(value), Map("out" -> this))
this
@@ -776,11 +894,17 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
+ def set(other: Array[Double]): NDArray = {
+ require(writable, "trying to assign to a readonly NDArray")
+ syncCopyfrom(other)
+ this
+ }
+
def +(other: NDArray): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_plus", Seq(this, other))
}
- def +(other: Float): NDArray = {
+ def +(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_plus_scalar", Seq(this, other))
}
@@ -792,7 +916,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def +=(other: Float): NDArray = {
+ def +=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to add to a readonly NDArray")
}
@@ -804,7 +928,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_minus", Seq(this, other))
}
- def -(other: Float): NDArray = {
+ def -(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_minus_scalar", Seq(this, other))
}
@@ -816,7 +940,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def -=(other: Float): NDArray = {
+ def -=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to subtract from a readonly NDArray")
}
@@ -828,7 +952,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_mul", Seq(this, other))
}
- def *(other: Float): NDArray = {
+ def *(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_mul_scalar", Seq(this, other))
}
@@ -844,7 +968,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def *=(other: Float): NDArray = {
+ def *=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to multiply to a readonly NDArray")
}
@@ -856,7 +980,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_div", Seq(this, other))
}
- def /(other: Float): NDArray = {
+ def /(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_div_scalar", Seq(this, other))
}
@@ -868,7 +992,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def /=(other: Float): NDArray = {
+ def /=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to divide from a readonly NDArray")
}
@@ -880,7 +1004,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.power(this, other)
}
- def **(other: Float): NDArray = {
+ def **(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.power(this, other)
}
@@ -888,7 +1012,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_power", Seq(this, other), Map("out" -> this))
}
- def **=(other: Float): NDArray = {
+ def **=(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(this, other), Map("out" -> this))
}
@@ -896,7 +1020,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.greater(this, other)
}
- def >(other: Float): NDArray = {
+ def >(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.greater(this, other)
}
@@ -904,7 +1028,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.greaterEqual(this, other)
}
- def >=(other: Float): NDArray = {
+ def >=(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.greaterEqual(this, other)
}
@@ -912,7 +1036,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.lesser(this, other)
}
- def <(other: Float): NDArray = {
+ def <(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.lesser(this, other)
}
@@ -920,7 +1044,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.lesserEqual(this, other)
}
- def <=(other: Float): NDArray = {
+ def <=(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.lesserEqual(this, other)
}
@@ -928,7 +1052,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
NDArray.genericNDArrayFunctionInvoke("_mod", Seq(this, other))
}
- def %(other: Float): NDArray = {
+ def %(other: MX_PRIMITIVE_TYPE): NDArray = {
NDArray.genericNDArrayFunctionInvoke("_mod_scalar", Seq(this, other))
}
@@ -940,7 +1064,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this
}
- def %=(other: Float): NDArray = {
+ def %=(other: MX_PRIMITIVE_TYPE): NDArray = {
if (!writable) {
throw new IllegalArgumentException("trying to take modulo from a readonly NDArray")
}
@@ -956,6 +1080,14 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
internal.toFloatArray
}
+ /**
+ * Return a copied flat java array of current array (row-major) with datatype Float64/Double.
+ * @return A copy of array content.
+ */
+ def toFloat64Array: Array[Double] = {
+ internal.toDoubleArray
+ }
+
def internal: NDArrayInternal = {
val myType = dtype
val arrLength = DType.numOfBytes(myType) * size
@@ -975,6 +1107,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
this.toArray(0)
}
+ def toFloat64Scalar: Double = {
+ require(shape == Shape(1), "The current array is not a scalar")
+ this.toFloat64Array(0)
+ }
+
/**
* Copy the content of current array to other.
*
@@ -997,7 +1134,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* @return The copy target NDArray
*/
def copyTo(ctx: Context): NDArray = {
- val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true))
+ val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true, dtype = dtype))
copyTo(ret)
}
@@ -1047,11 +1184,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
private[mxnet] object NDArrayConversions {
implicit def int2Scalar(x: Int): NDArrayConversions = new NDArrayConversions(x.toFloat)
- implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x.toFloat)
+ implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x)
implicit def float2Scalar(x: Float): NDArrayConversions = new NDArrayConversions(x)
}
-private[mxnet] class NDArrayConversions(val value: Float) {
+private[mxnet] class NDArrayConversions(val value: MX_PRIMITIVE_TYPE) {
def +(other: NDArray): NDArray = {
other + value
}
@@ -1145,34 +1282,39 @@ private[mxnet] class NDArrayFuncReturn(private[mxnet] val arr: Array[NDArray]) {
def waitToRead(): Unit = head.waitToRead()
def context: Context = head.context
def set(value: Float): NDArray = head.set(value)
+ def set(value: Double): NDArray = head.set(value)
def set(other: NDArray): NDArray = head.set(other)
def set(other: Array[Float]): NDArray = head.set(other)
+ def set(other: Array[Double]): NDArray = head.set(other)
def +(other: NDArray): NDArray = head + other
- def +(other: Float): NDArray = head + other
+ def +(other: MX_PRIMITIVE_TYPE): NDArray = head + other
def +=(other: NDArray): NDArray = head += other
- def +=(other: Float): NDArray = head += other
+ def +=(other: MX_PRIMITIVE_TYPE): NDArray = head += other
def -(other: NDArray): NDArray = head - other
- def -(other: Float): NDArray = head - other
+ def -(other: MX_PRIMITIVE_TYPE): NDArray = head - other
def -=(other: NDArray): NDArray = head -= other
- def -=(other: Float): NDArray = head -= other
+ def -=(other: MX_PRIMITIVE_TYPE): NDArray = head -= other
def *(other: NDArray): NDArray = head * other
- def *(other: Float): NDArray = head * other
+ def *(other: MX_PRIMITIVE_TYPE): NDArray = head * other
def unary_-(): NDArray = -head
def *=(other: NDArray): NDArray = head *= other
- def *=(other: Float): NDArray = head *= other
+ def *=(other: MX_PRIMITIVE_TYPE): NDArray = head *= other
def /(other: NDArray): NDArray = head / other
+ def /(other: MX_PRIMITIVE_TYPE): NDArray = head / other
def **(other: NDArray): NDArray = head ** other
- def **(other: Float): NDArray = head ** other
+ def **(other: MX_PRIMITIVE_TYPE): NDArray = head ** other
def >(other: NDArray): NDArray = head > other
- def >(other: Float): NDArray = head > other
+ def >(other: MX_PRIMITIVE_TYPE): NDArray = head > other
def >=(other: NDArray): NDArray = head >= other
- def >=(other: Float): NDArray = head >= other
+ def >=(other: MX_PRIMITIVE_TYPE): NDArray = head >= other
def <(other: NDArray): NDArray = head < other
- def <(other: Float): NDArray = head < other
+ def <(other: MX_PRIMITIVE_TYPE): NDArray = head < other
def <=(other: NDArray): NDArray = head <= other
- def <=(other: Float): NDArray = head <= other
+ def <=(other: MX_PRIMITIVE_TYPE): NDArray = head <= other
def toArray: Array[Float] = head.toArray
+ def toFloat64Array: Array[Double] = head.toFloat64Array
def toScalar: Float = head.toScalar
+ def toFloat64Scalar: Double = head.toFloat64Scalar
def copyTo(other: NDArray): NDArray = head.copyTo(other)
def copyTo(ctx: Context): NDArray = head.copyTo(ctx)
def copy(): NDArray = head.copy()
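Taken together, the NDArray changes above let Double data flow end to end: Array[Double] sources produce Float64 arrays, Double scalars reach the operators through MX_PRIMITIVE_TYPE, and results can be read back at full precision. A short sketch using the overloads added in this file; the object name is illustrative and the expected outputs are indicative:

```scala
import org.apache.mxnet.{NDArray, Shape}

object Float64NDArraySketch {
  def main(args: Array[String]): Unit = {
    // Array[Double] sources now produce Float64 NDArrays.
    val a = NDArray.array(Array(1.0, 2.0, 3.0), Shape(1, 3))

    // Double scalars reach the operators through the new MX_PRIMITIVE_TYPE overloads.
    val b = a + 0.5
    val c = NDArray.full(Shape(1, 3), 3.0)

    println(a.dtype)                         // float64
    println(b.toFloat64Array.mkString(", ")) // 1.5, 2.5, 3.5
    println(c.toFloat64Array.mkString(", ")) // 3.0, 3.0, 3.0

    Seq(a, b, c).foreach(_.dispose())
  }
}
```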
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index a84bd106b763..e30098c3088b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -53,9 +53,9 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
val label = currentBatch.label(0)
// properties
val res = (
- // TODO: need to allow user to specify DType and Layout
- IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, Layout.UNDEFINED)),
- IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, Layout.UNDEFINED)),
+ // TODO: need to allow user to specify Layout
+ IndexedSeq(new DataDesc(dataName, data.shape, data.dtype, Layout.UNDEFINED)),
+ IndexedSeq(new DataDesc(labelName, label.shape, label.dtype, Layout.UNDEFINED)),
ListMap(dataName -> data.shape),
ListMap(labelName -> label.shape),
data.shape(0))
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 0032a54dd802..e690abba0d13 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -61,7 +61,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label") {
- this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE, Layout.UNDEFINED),
+ this(IO.initDataDesc(data, allowEmpty = false, dataName,
+ if (data == null || data.isEmpty) MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED),
IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED),
dataBatchSize, shuffle, lastBatchHandle)
}
@@ -272,7 +273,7 @@ object NDArrayIter {
*/
def addData(name: String, data: NDArray): Builder = {
this.data = this.data ++ IndexedSeq((new DataDesc(name,
- data.shape, DType.Float32, Layout.UNDEFINED), data))
+ data.shape, data.dtype, Layout.UNDEFINED), data))
this
}
@@ -284,7 +285,7 @@ object NDArrayIter {
*/
def addLabel(name: String, label: NDArray): Builder = {
this.label = this.label ++ IndexedSeq((new DataDesc(name,
- label.shape, DType.Float32, Layout.UNDEFINED), label))
+ label.shape, label.dtype, Layout.UNDEFINED), label))
this
}
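With the NDArrayIter changes above, the generated DataDesc entries inherit the dtype of the backing NDArrays instead of being hard-coded to Float32 (labels still default to MX_REAL_TYPE). A minimal sketch, assuming the auxiliary NDArrayIter constructor shown above and the usual DataBatch data/label accessors; names and expected dtypes are indicative:

```scala
import org.apache.mxnet.{NDArray, Shape}
import org.apache.mxnet.io.NDArrayIter

object IterDtypeSketch {
  def main(args: Array[String]): Unit = {
    // Float64 data alongside a Float32 label.
    val data  = NDArray.array(Array(1.0, 2.0, 3.0, 4.0), Shape(2, 2))
    val label = NDArray.array(Array(1f, 0f), Shape(2))

    // The data DataDesc now takes its dtype from data(0) (Float64 here);
    // the label DataDesc still defaults to MX_REAL_TYPE (Float32).
    val iter = new NDArrayIter(IndexedSeq(data), IndexedSeq(label), dataBatchSize = 1)

    val batch = iter.next()
    println(batch.data(0).dtype)  // expected: float64
    println(batch.label(0).dtype) // expected: float32
  }
}
```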
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index 198102d2377f..67809c158aff 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -91,17 +91,26 @@ object NDArray extends NDArrayBase {
def full(shape: Shape, value: Float, ctx: Context): NDArray
= org.apache.mxnet.NDArray.full(shape, value, ctx)
+ def full(shape: Shape, value: Double, ctx: Context): NDArray
+ = org.apache.mxnet.NDArray.full(shape, value, ctx)
+
def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+ def power(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+ def power(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+ def maximum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+ def maximum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+ def minimum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+ def minimum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
/**
@@ -111,6 +120,7 @@ object NDArray extends NDArrayBase {
*/
def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+ def equal(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
/**
* Returns the result of element-wise **not equal to** (!=) comparison operation
@@ -120,6 +130,7 @@ object NDArray extends NDArrayBase {
*/
def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+ def notEqual(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
/**
* Returns the result of element-wise **greater than** (>) comparison operation
@@ -129,6 +140,7 @@ object NDArray extends NDArrayBase {
*/
def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+ def greater(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
/**
* Returns the result of element-wise **greater than or equal to** (>=) comparison
@@ -140,6 +152,8 @@ object NDArray extends NDArrayBase {
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
def greaterEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+ def greaterEqual(lhs: NDArray, rhs: Double): NDArray
+ = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
/**
* Returns the result of element-wise **lesser than** (<) comparison operation
@@ -149,6 +163,7 @@ object NDArray extends NDArrayBase {
*/
def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+ def lesser(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
/**
* Returns the result of element-wise **lesser than or equal to** (<=) comparison
@@ -160,6 +175,8 @@ object NDArray extends NDArrayBase {
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
def lesserEqual(lhs: NDArray, rhs: Float): NDArray
= org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+ def lesserEqual(lhs: NDArray, rhs: Double): NDArray
+ = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
/**
* Create a new NDArray that copies content from source_array.
@@ -172,6 +189,18 @@ object NDArray extends NDArrayBase {
= org.apache.mxnet.NDArray.array(
sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)
+ /**
+ * Create a new NDArray that copies content from source_array.
+ * @param sourceArr Source data (list of Doubles) to create NDArray from.
+ * @param shape shape of the NDArray
+ * @param ctx The context of the NDArray, default to current default context.
+ * @return The created NDArray.
+ */
+ def arrayWithDouble(sourceArr: java.util.List[java.lang.Double], shape: Shape,
+ ctx: Context = null): NDArray
+ = org.apache.mxnet.NDArray.array(
+ sourceArr.asScala.map(ele => Double.unbox(ele)).toArray, shape, ctx)
+
/**
* Returns evenly spaced values within a given interval.
* Values are generated within the half-open interval [`start`, `stop`). In other
@@ -205,6 +234,10 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
}
+ def this(arr: Array[Double], shape: Shape, ctx: Context) = {
+ this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
+ }
+
def this(arr: java.util.List[java.lang.Float], shape: Shape, ctx: Context) = {
this(NDArray.array(arr, shape, ctx))
}
@@ -304,41 +337,59 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
* @return Current NDArray
*/
def set(value: Float): NDArray = nd.set(value)
+ def set(value: Double): NDArray = nd.set(value)
def set(other: NDArray): NDArray = nd.set(other)
def set(other: Array[Float]): NDArray = nd.set(other)
+ def set(other: Array[Double]): NDArray = nd.set(other)
def add(other: NDArray): NDArray = this.nd + other.nd
def add(other: Float): NDArray = this.nd + other
+ def add(other: Double): NDArray = this.nd + other
def addInplace(other: NDArray): NDArray = this.nd += other
def addInplace(other: Float): NDArray = this.nd += other
+ def addInplace(other: Double): NDArray = this.nd += other
def subtract(other: NDArray): NDArray = this.nd - other
def subtract(other: Float): NDArray = this.nd - other
+ def subtract(other: Double): NDArray = this.nd - other
def subtractInplace(other: NDArray): NDArray = this.nd -= other
def subtractInplace(other: Float): NDArray = this.nd -= other
+ def subtractInplace(other: Double): NDArray = this.nd -= other
def multiply(other: NDArray): NDArray = this.nd * other
def multiply(other: Float): NDArray = this.nd * other
+ def multiply(other: Double): NDArray = this.nd * other
def multiplyInplace(other: NDArray): NDArray = this.nd *= other
def multiplyInplace(other: Float): NDArray = this.nd *= other
+ def multiplyInplace(other: Double): NDArray = this.nd *= other
def div(other: NDArray): NDArray = this.nd / other
def div(other: Float): NDArray = this.nd / other
+ def div(other: Double): NDArray = this.nd / other
def divInplace(other: NDArray): NDArray = this.nd /= other
def divInplace(other: Float): NDArray = this.nd /= other
+ def divInplace(other: Double): NDArray = this.nd /= other
def pow(other: NDArray): NDArray = this.nd ** other
def pow(other: Float): NDArray = this.nd ** other
+ def pow(other: Double): NDArray = this.nd ** other
def powInplace(other: NDArray): NDArray = this.nd **= other
def powInplace(other: Float): NDArray = this.nd **= other
+ def powInplace(other: Double): NDArray = this.nd **= other
def mod(other: NDArray): NDArray = this.nd % other
def mod(other: Float): NDArray = this.nd % other
+ def mod(other: Double): NDArray = this.nd % other
def modInplace(other: NDArray): NDArray = this.nd %= other
def modInplace(other: Float): NDArray = this.nd %= other
+ def modInplace(other: Double): NDArray = this.nd %= other
def greater(other: NDArray): NDArray = this.nd > other
def greater(other: Float): NDArray = this.nd > other
+ def greater(other: Double): NDArray = this.nd > other
def greaterEqual(other: NDArray): NDArray = this.nd >= other
def greaterEqual(other: Float): NDArray = this.nd >= other
+ def greaterEqual(other: Double): NDArray = this.nd >= other
def lesser(other: NDArray): NDArray = this.nd < other
def lesser(other: Float): NDArray = this.nd < other
+ def lesser(other: Double): NDArray = this.nd < other
def lesserEqual(other: NDArray): NDArray = this.nd <= other
def lesserEqual(other: Float): NDArray = this.nd <= other
+ def lesserEqual(other: Double): NDArray = this.nd <= other
/**
* Return a copied flat java array of current array (row-major).
@@ -346,6 +397,12 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
*/
def toArray: Array[Float] = nd.toArray
+ /**
+ * Return a copied flat java array of current array (row-major).
+ * @return A copy of array content.
+ */
+ def toFloat64Array: Array[Double] = nd.toFloat64Array
+
/**
* Return a CPU scalar(float) of current ndarray.
* This ndarray must have shape (1,)
@@ -354,6 +411,14 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
*/
def toScalar: Float = nd.toScalar
+ /**
+ * Return a CPU scalar(double) of current ndarray.
+ * This ndarray must have shape (1,)
+ *
+ * @return The scalar representation of the ndarray.
+ */
+ def toFloat64Scalar: Double = nd.toFloat64Scalar
+
/**
* Copy the content of current array to other.
*
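For illustration only (not part of this patch), a minimal Scala sketch of how the Double overloads added to the Java-API NDArray above could be exercised; the shape, context, and object name are arbitrary assumptions:

import org.apache.mxnet.javaapi.{Context, NDArray, Shape}

object Float64NDArrayExample {
  def main(args: Array[String]): Unit = {
    val ctx = new Context("cpu", 0)
    // The new Array[Double] constructor yields a float64 NDArray.
    val nd = new NDArray(Array(1.0, 2.0, 3.0, 4.0), new Shape(Array(2, 2)), ctx)
    // Double-valued scalar arithmetic resolves to the Double overloads added above.
    val scaled = nd.multiply(2.5).add(1.0)
    // toFloat64Array copies the data back without truncating to single precision.
    println(scaled.toFloat64Array.mkString(", "))
  }
}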
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
index 2659b7848bc6..86c7eb29d2ef 100644
--- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -40,6 +40,15 @@ public void testCreateNDArray() {
new Shape(new int[]{1, 3}),
new Context("cpu", 0));
assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+
+ List
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
   * This method will take input as IndexedSeq one dimensional arrays and creates the
   * NDArray needed for inference. The array will be reshaped based on the input descriptors.
-  * @param input: An IndexedSequence of a one-dimensional array.
+  * @param input: An Indexed Sequence of a one-dimensional array of datatype
+  *               Float or Double
                   An IndexedSequence is needed when the model has more than one input.
   * @return Indexed sequence array of outputs
   */
-  def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]]
+  def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
+    : IndexedSeq[Array[T]]

  /**
    * Predict using NDArray as input.
@@ -123,13 +126,13 @@ class Predictor(modelPathPrefix: String,
   * Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference
   * The array will be reshaped based on the input descriptors.
   *
-  * @param input: An IndexedSequence of a one-dimensional array.
+  * @param input: An IndexedSequence of a one-dimensional array
+  *               of data type Float or Double.
                   An IndexedSequence is needed when the model has more than one input.
   * @return Indexed sequence array of outputs
   */
-  override def predict(input: IndexedSeq[Array[Float]])
-    : IndexedSeq[Array[Float]] = {
-
+  override def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
+    : IndexedSeq[Array[T]] = {
    require(input.length == inputDescriptors.length,
      s"number of inputs provided: ${input.length} does not match number of inputs " +
        s"in inputDescriptors: ${inputDescriptors.length}")
@@ -139,12 +142,30 @@ class Predictor(modelPathPrefix: String,
        s"number of elements:${i.length} in the input does not match the shape:" +
          s"${d.shape.toString()}")
    }
+
+   // Infer the dtype of input and call relevant method
+   val result = input(0)(0) match {
+     case d: Double => predictImpl(input.asInstanceOf[IndexedSeq[Array[Double]]])
+     case _ => predictImpl(input.asInstanceOf[IndexedSeq[Array[Float]]])
+   }
+
+   result.asInstanceOf[IndexedSeq[Array[T]]]
+  }
+
+  private def predictImpl[B, A <: MX_PRIMITIVE_TYPE]
+  (input: IndexedSeq[Array[B]])(implicit ev: B => A)
+    : IndexedSeq[Array[B]] = {
+
    var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray]

    for((i, d) <- input.zip(inputDescriptors)) {
      val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1)
-
-     inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape)))
+     if (d.dtype == DType.Float64) {
+       inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Double]], Shape(shape)))
+     }
+     else {
+       inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Float]], Shape(shape)))
+     }
    }

    // rebind with batchsize 1
@@ -158,7 +179,8 @@ class Predictor(modelPathPrefix: String,
    val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
      inputND.toIndexedSeq, dataBatchSize = 1)))

-   val result = resultND.map((f : NDArray) => f.toArray)
+   val result =
+     resultND.map((f : NDArray) => if (f.dtype == DType.Float64) f.toFloat64Array else f.toArray)

    mxNetHandler.execute(inputND.foreach(_.dispose))
    mxNetHandler.execute(resultND.foreach(_.dispose))
@@ -168,9 +190,11 @@ class Predictor(modelPathPrefix: String,
      mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true))
    }

-   result
+   result.asInstanceOf[IndexedSeq[Array[B]]]
  }
+
+
  /**
    * Predict using NDArray as input
    * This method is useful when the input is a batch of data
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
index 0466693be9bc..146fe93105e4 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -72,6 +72,30 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
    predictor.predict(input).toArray
  }

+ /**
+   * Takes input as Array of one dimensional arrays and creates the NDArray needed for inference
+   * The array will be reshaped based on the input descriptors. Example of calling in Java:
+   *
+   * <pre>
+   * {@code
+   * double tmp[][] = new double[1][224];
+   * for (int x = 0; x < 1; x++)
+   *   for (int y = 0; y < 224; y++)
+   *     tmp[x][y] = (int)(Math.random()*10);
+   * predictor.predict(tmp);
+   * }
+   * </pre>
+   *
+   * @param input: An Array of a one-dimensional array.
+                   An extra Array is needed for when the model has more than one input.
+   * @return Indexed sequence array of outputs
+   */
+
+ def predict(input: Array[Array[Double]]):
+ Array[Array[Double]] = {
+   predictor.predict(input).toArray
+ }
+
 /**
   * Takes input as List of one dimensional arrays and creates the NDArray needed for inference
   * The array will be reshaped based on the input descriptors.
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
index b28aeba1deed..d9ccec468791 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
@@ -22,7 +22,7 @@ import java.nio.file.{Files, Paths}
 import java.util
 import org.apache.mxnet.module.Module
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet.{Context, DType, DataDesc, NDArray, Shape}
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -127,6 +150,29 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
 }
+ test("ClassifierSuite-flatFloat64Array-topK") {
+   val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+   val inputData = Array.fill[Double](12)(1d)
+
+   val predictResult : IndexedSeq[Array[Double]] =
+     IndexedSeq[Array[Double]](Array(.98d, 0.97d, 0.96d, 0.99d))
+
+   val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+   Mockito.doReturn(predictResult).when(testClassifier.predictor)
+     .predict(any(classOf[IndexedSeq[Array[Double]]]))
+
+   val result: IndexedSeq[(String, Double)] = testClassifier.
+     classify(IndexedSeq(inputData), topK = Some(10))
+
+   assert((result(0)_2).getClass == 1d.getClass)
+
+   assertResult(predictResult(0).sortBy(-_)) {
+     result.map(_._2).toArray
+   }
+
+ }
+
 test("ClassifierSuite-flatArrayInput") {
   val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
   val inputData = Array.fill[Float](12)(1)
@@ -147,6 +170,28 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
   }
 }
+ test("ClassifierSuite-flatArrayFloat64Input") {
+   val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+   val inputData = Array.fill[Double](12)(1d)
+
+   val predictResult : IndexedSeq[Array[Double]] =
+     IndexedSeq[Array[Double]](Array(.98d, 0.97d, 0.96d, 0.99d))
+
+   val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+   Mockito.doReturn(predictResult).when(testClassifier.predictor)
+     .predict(any(classOf[IndexedSeq[Array[Double]]]))
+
+   val result: IndexedSeq[(String, Double)] = testClassifier.
+     classify(IndexedSeq(inputData))
+
+   assert((result(0)_2).getClass == 1d.getClass)
+
+   assertResult(predictResult(0)) {
+     result.map(_._2).toArray
+   }
+ }
+
 test("ClassifierSuite-NDArray1InputWithoutTopK") {
   val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
   val inputDataShape = Shape(1, 3, 2, 2)
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
index 1c291e1e7b3c..5198c4a1f309 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
@@ -68,6 +68,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
    val result = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2))
    assert(result.shape == inputDescriptor(0).shape.drop(1))
+   assert(result.dtype == DType.Float32)
+
+   val resultFloat64 = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2), DType.Float64)
+   assert(resultFloat64.dtype == DType.Float64)
  }
  test("ImageClassifierSuite-testWithInputImage") {
@@ -106,8 +110,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
        predictResult(i).map(_._2).toArray
      }
    }
+
  }
+
  test("ImageClassifierSuite-testWithInputBatchImage") {
    val dType = DType.Float32
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512),
@@ -152,4 +158,5 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
      }
    }
  }
+
 }
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
index 509ffb35db8d..9afbc9b3d4a8 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.infer
 import org.apache.mxnet.io.NDArrayIter
 import org.apache.mxnet.module.{BaseModule, Module}
-import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape}
+import org.apache.mxnet._
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -91,6 +91,36 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
      , any[Option[BaseModule]], any[String])
  }
+ test("PredictorSuite-testWithFlatFloat64Arrays") {
+
+   val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
+     layout = Layout.NCHW, dtype = DType.Float64))
+   val inputData = Array.fill[Double](12)(1d)
+
+   // this will disposed at the end of the predict call on Predictor.
+   val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2), dtype = DType.Float64))
+
+   val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+   Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+     .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+   val testFun = testPredictor.predict(IndexedSeq(inputData))
+
+   assert(testFun.size == 1, "output size should be 1 ")
+
+   assert(testFun(0)(0).getClass == 1d.getClass)
+
+   assert(Array.fill[Double](12)(1d).mkString == testFun(0).mkString)
+
+   // Verify that the module was bound with batch size 1 and rebound back to the original
+   // input descriptor. the number of times is twice here because loadModule overrides the
+   // initial bind.
+   Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+     any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
+     , any[Option[BaseModule]], any[String])
+ }
+
 test("PredictorSuite-testWithNDArray") {
   val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
     layout = Layout.NCHW))
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index d684c6d13564..ea6e9c8f5ba4 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -424,6 +424,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
  return ret;
 }
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
+  (JNIEnv *env, jobject obj, jlong arrayPtr, jdoubleArray sourceArr, jint arrSize) {
+  jdouble *sourcePtr = env->GetDoubleArrayElements(sourceArr, NULL);
+  int ret = MXNDArraySyncCopyFromCPU(reinterpret_cast
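As a closing usage sketch (not part of this patch), the double-precision inference path the changes above enable could be exercised roughly as follows; the model prefix, input shape, and object name are hypothetical:

import org.apache.mxnet.{DType, DataDesc, Layout, Shape}
import org.apache.mxnet.infer.Predictor

object Float64PredictExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical model prefix; substitute a real exported checkpoint.
    val inputDescriptor = IndexedSeq(
      new DataDesc("data", Shape(1, 3, 224, 224), layout = Layout.NCHW, dtype = DType.Float64))
    val predictor = new Predictor("my-model/resnet", inputDescriptor)

    // With the specialized predict above, an Array[Double] input keeps float64 precision
    // end to end; an Array[Float] input follows the existing float32 path.
    val input = Array.fill[Double](3 * 224 * 224)(0.5)
    val output: IndexedSeq[Array[Double]] = predictor.predict(IndexedSeq(input))
    println(output.head.take(5).mkString(", "))
  }
}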