Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1260] Float64 DType computation support in Scala/Java #13678

Merged
merged 41 commits into from
Jan 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
4aabf1e
Added Float64 as a supported datatype in NDArray
piyushghai Dec 18, 2018
002720c
Added unit tests for Float64 in NDArray
piyushghai Dec 18, 2018
96f2849
Fix for failing Clojure unit tests
piyushghai Dec 19, 2018
8754e41
Added Float and Double as MX_PRIMITIVES for computation in Scala
piyushghai Dec 19, 2018
d92840a
Trying out second approach --> Private Impl methods with generic sign…
piyushghai Dec 19, 2018
c21ac15
Fixed errors in *= method
piyushghai Dec 20, 2018
78dc31c
Added Float64 in IO.scala and DataIter.scala
piyushghai Dec 20, 2018
4cdb77c
Added another testcase for IO.DataDesc creation
piyushghai Dec 20, 2018
2d275fb
Fixed failing CI
piyushghai Dec 20, 2018
9bc0a50
Added Float64 in Predictor class
piyushghai Dec 21, 2018
5ae1040
Added Float64 in Classifier class
piyushghai Dec 21, 2018
d455be5
Added Double as a possible return type to : classifyWithNDArray
piyushghai Dec 21, 2018
a5ff826
Added unit tests for Classifier and Predictor.scala classes for Float…
piyushghai Dec 21, 2018
c598da0
Approach 3 --> Using a trait to mirror Float and Double in Scala
piyushghai Dec 22, 2018
b5bc531
Added comments on MX_PRIMITIVES.scala
piyushghai Dec 22, 2018
1b43071
Added Float64/Double support for inference in ImageClassifier APIs
piyushghai Dec 26, 2018
98e41f0
Added unary- and compareTo in MX_NUMBER_LIKE
piyushghai Dec 27, 2018
49c9af5
Renamed MX_NUMBER_LIKE to MX_PRIMITIVE_TYPE
piyushghai Dec 27, 2018
9a83b2a
Fixed linting issue
piyushghai Dec 27, 2018
c9a0261
Now specifying dType from the available data in copyTo and MXDataIter…
piyushghai Dec 27, 2018
07bd464
Add primitives support handling to the generator for proper conversion
gigasquid Dec 27, 2018
c9fc37a
Reduced code duplication in classify method in Classifier.scala
piyushghai Dec 28, 2018
32b8d99
Fix infer package for new signatures and address some bugs
gigasquid Dec 28, 2018
b450355
Removed code duplication in getPixelsArray
piyushghai Dec 28, 2018
8167005
Changed classifyWithNDArray method in Classifier.scala
piyushghai Dec 28, 2018
422f9de
Removed code duplication in predictImpl
piyushghai Dec 28, 2018
fdc8a1e
Satisfying lint god _/\_
piyushghai Dec 28, 2018
e5a5bfe
remove debugging
gigasquid Dec 28, 2018
db77274
Fixed failing PredictorSuite test
piyushghai Dec 28, 2018
72a734b
Renamed MX_FLOAT to Camel case
piyushghai Jan 2, 2019
c67e2dc
Revert "Renamed MX_FLOAT to Camel case"
piyushghai Jan 2, 2019
50abf88
Added an implicit conversion from int--> float to support int operati…
piyushghai Jan 2, 2019
3b52b50
Added Float64 as a training option to ImClassification Suite. Also ad…
piyushghai Jan 3, 2019
ddc8e78
Satisfy Lint God _/\_
piyushghai Jan 3, 2019
3d85167
Added Float64 support in Java NDArray
piyushghai Jan 4, 2019
29bd834
Added Float64 support in Java's Predictor API
piyushghai Jan 4, 2019
b7f851e
Added yours truly to the Contributors list
piyushghai Jan 4, 2019
e329ee4
Added method comments on Predictor.predict with Array[Double] as a po…
piyushghai Jan 4, 2019
69c35b3
Added method comments explaining what MX_PRIMITIVE_TYPE is
piyushghai Jan 4, 2019
fa9fb45
Fixed errors cause by rebasing with master
piyushghai Jan 10, 2019
e562e94
Added licences to the files
piyushghai Jan 10, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ List of Contributors
* [Yuxi Hu](https://github.com/yuxihu)
* [Harsh Patel](https://github.com/harshp8l)
* [Xiao Wang](https://github.com/BeyonderXX)
* [Piyush Ghai](https://github.com/piyushghai)

Label Bot
---------
Expand Down
242 changes: 117 additions & 125 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
(ns org.apache.clojure-mxnet.infer
(:refer-clojure :exclude [type])
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.shape :as shape]
[org.apache.clojure-mxnet.util :as util]
Expand Down Expand Up @@ -62,10 +63,12 @@
(defprotocol AImageClassifier
(classify-image
[wrapped-image-classifier image]
[wrapped-image-classifier image topk])
[wrapped-image-classifier image topk]
[wrapped-image-classifier image topk dtype])
(classify-image-batch
[wrapped-image-classifier images]
[wrapped-image-classifier images topk]))
[wrapped-image-classifier images topk]
[wrapped-image-classifier images topk dtype]))

(defprotocol AObjectDetector
(detect-objects
Expand All @@ -80,7 +83,8 @@

(extend-protocol APredictor
WrappedPredictor
(predict [wrapped-predictor inputs]
(predict
[wrapped-predictor inputs]
(util/validate! ::wrapped-predictor wrapped-predictor
"Invalid predictor")
(util/validate! ::vec-of-float-arrays inputs
Expand All @@ -101,62 +105,50 @@

(extend-protocol AClassifier
WrappedClassifier
(classify [wrapped-classifier inputs]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(classify wrapped-classifier inputs nil))
(classify [wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classify (:classifier wrapped-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk))))
(classify-with-ndarray [wrapped-classifier inputs]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(classify-with-ndarray wrapped-classifier inputs nil))
(classify-with-ndarray [wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classifyWithNDArray (:classifier wrapped-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk))))
(classify
([wrapped-classifier inputs]
(classify wrapped-classifier inputs nil))
([wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classify (:classifier wrapped-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk)))))
(classify-with-ndarray
([wrapped-classifier inputs]
(classify-with-ndarray wrapped-classifier inputs nil))
([wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classifyWithNDArray (:classifier wrapped-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk)))))
WrappedImageClassifier
(classify [wrapped-image-classifier inputs]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(classify wrapped-image-classifier inputs nil))
(classify [wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classify (:image-classifier wrapped-image-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk))))
(classify-with-ndarray [wrapped-image-classifier inputs]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(classify-with-ndarray wrapped-image-classifier inputs nil))
(classify-with-ndarray [wrapped-image-classifier inputs topk]
(classify
([wrapped-image-classifier inputs]
(classify wrapped-image-classifier inputs nil))
([wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classify (:image-classifier wrapped-image-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk)))))
(classify-with-ndarray
([wrapped-image-classifier inputs]
(classify-with-ndarray wrapped-image-classifier inputs nil))
([wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
Expand All @@ -165,83 +157,83 @@
(util/coerce-return-recursive
(.classifyWithNDArray (:image-classifier wrapped-image-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk)))))
(util/->int-option topk))))))

(s/def ::image #(instance? BufferedImage %))
(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})

(extend-protocol AImageClassifier
WrappedImageClassifier
(classify-image [wrapped-image-classifier image]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::image image "Invalid image")
(classify-image wrapped-image-classifier image nil))
(classify-image [wrapped-image-classifier image topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classifyImage (:image-classifier wrapped-image-classifier)
image
(util/->int-option topk))))
(classify-image-batch [wrapped-image-classifier images]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(classify-image-batch wrapped-image-classifier images nil))
(classify-image-batch [wrapped-image-classifier images topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classifyImageBatch (:image-classifier wrapped-image-classifier)
images
(util/->int-option topk)))))
(classify-image
([wrapped-image-classifier image]
(classify-image wrapped-image-classifier image nil dtype/FLOAT32))
([wrapped-image-classifier image topk]
(classify-image wrapped-image-classifier image topk dtype/FLOAT32))
([wrapped-image-classifier image topk dtype]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/validate! ::dtype dtype "Invalid dtype")
(util/coerce-return-recursive
(.classifyImage (:image-classifier wrapped-image-classifier)
image
(util/->int-option topk)
dtype))))
(classify-image-batch
([wrapped-image-classifier images]
(classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
([wrapped-image-classifier images topk]
(classify-image-batch wrapped-image-classifier images topk dtype/FLOAT32))
([wrapped-image-classifier images topk dtype]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/validate! ::dtype dtype "Invalid dtype")
(util/coerce-return-recursive
(.classifyImageBatch (:image-classifier wrapped-image-classifier)
images
(util/->int-option topk)
dtype)))))

(extend-protocol AObjectDetector
WrappedObjectDetector
(detect-objects [wrapped-detector image]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::image image "Invalid image")
(detect-objects wrapped-detector image nil))
(detect-objects [wrapped-detector image topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.imageObjectDetect (:object-detector wrapped-detector)
image
(util/->int-option topk))))
(detect-objects-batch [wrapped-detector images]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(detect-objects-batch wrapped-detector images nil))
(detect-objects-batch [wrapped-detector images topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.imageBatchObjectDetect (:object-detector wrapped-detector)
images
(util/->int-option topk))))
(detect-objects-with-ndarrays [wrapped-detector input-arrays]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::vec-of-ndarrays input-arrays
"Invalid inputs")
(detect-objects-with-ndarrays wrapped-detector input-arrays nil))
(detect-objects-with-ndarrays [wrapped-detector input-arrays topk]
(detect-objects
([wrapped-detector image]
(detect-objects wrapped-detector image nil))
([wrapped-detector image topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::vec-of-ndarrays input-arrays
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.objectDetectWithNDArray (:object-detector wrapped-detector)
(util/vec->indexed-seq input-arrays)
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.imageObjectDetect (:object-detector wrapped-detector)
image
(util/->int-option topk)))))
(detect-objects-batch
([wrapped-detector images]
(detect-objects-batch wrapped-detector images nil))
([wrapped-detector images topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.imageBatchObjectDetect (:object-detector wrapped-detector)
images
(util/->int-option topk)))))
(detect-objects-with-ndarrays
([wrapped-detector input-arrays]
(detect-objects-with-ndarrays wrapped-detector input-arrays nil))
([wrapped-detector input-arrays topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::vec-of-ndarrays input-arrays
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.objectDetectWithNDArray (:object-detector wrapped-detector)
(util/vec->indexed-seq input-arrays)
(util/->int-option topk))))))

(defprotocol AInferenceFactory
(create-predictor [factory] [factory opts])
Expand Down Expand Up @@ -335,7 +327,7 @@
[image input-shape-vec]
(util/validate! ::image image "Invalid image")
(util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
(ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec)))
(ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32))

(s/def ::image-path string?)
(s/def ::image-paths (s/coll-of ::image-path))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.primitives
(:import (org.apache.mxnet MX_PRIMITIVES$MX_FLOAT MX_PRIMITIVES$MX_Double
MX_PRIMITIVES$MX_PRIMITIVE_TYPE)))


;;; Defines customer mx primitives that can be used for mathematical computations
;;; in NDArrays to control precision. Currently Float and Double are supported

;;; For purposes of automatic conversion in ndarray functions, doubles are default
;; to specify using floats you must use a Float

(defn mx-float
"Creates a MXNet float primitive"
[num]
(new MX_PRIMITIVES$MX_FLOAT num))

(defn mx-double
"Creates a MXNet double primitive"
[num]
(new MX_PRIMITIVES$MX_Double num))

(defn ->num
"Returns the underlying number value"
[primitive]
(.data primitive))

(defn primitive? [x]
(instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE x))

Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
(:require [clojure.spec.alpha :as s]
[t6.from-scala.core :refer [$ $$] :as $]
[clojure.string :as string]
[org.apache.clojure-mxnet.primitives :as primitives]
[org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet NDArray)
(scala Product Tuple2 Tuple3)
Expand All @@ -36,7 +37,8 @@
"byte<>" "byte-array"
"java.lang.String<>" "vec-or-strings"
"org.apache.mxnet.NDArray" "ndarray"
"org.apache.mxnet.Symbol" "sym"})
"org.apache.mxnet.Symbol" "sym"
"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"})

(def symbol-param-coerce {"java.lang.String" "sym-name"
"float" "num"
Expand Down Expand Up @@ -144,6 +146,8 @@
(and (get targets "int<>") (vector? param)) (int-array param)
(and (get targets "float<>") (vector? param)) (float-array param)
(and (get targets "java.lang.String<>") (vector? param)) (into-array param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param)
:else param))

(defn nil-or-coerce-param [param targets]
Expand Down Expand Up @@ -177,6 +181,7 @@
(instance? Map return-val) (scala-map->map return-val)
(instance? Tuple2 return-val) (tuple->vec return-val)
(instance? Tuple3 return-val) (tuple->vec return-val)
(primitives/primitive? return-val) (primitives/->num return-val)
:else return-val))

(defn coerce-return-recursive [return-val]
Expand Down
Loading