Skip to content

Commit

Permalink
[MXNET-1260] Float64 DType computation support in Scala/Java (apache#…
Browse files Browse the repository at this point in the history
…13678)

* Added Float64 as a supported datatype in NDArray

* Added unit tests for Float64 in NDArray

* Fix for failing Clojure unit tests

* Added Float and Double as MX_PRIMITIVES for computation in Scala

* Trying out second approach --> Private Impl methods with generic signature, and public methods calling the Impls

* Fixed errors in *= method

* Added Float64 in IO.scala and DataIter.scala

* Added another testcase for IO.DataDesc creation

* Fixed failing CI

* Added Float64 in Predictor class

* Added Float64 in Classifier class

* Added Double as a possible return type to : classifyWithNDArray

* Added unit tests for Classifier and Predictor.scala classes for Float64/Double

* Approach 3 --> Using a trait to mirror Float and Double in Scala

* Added comments on MX_PRIMITIVES.scala

* Added Float64/Double support for inference in ImageClassifier APIs

* Added unary- and compareTo in MX_NUMBER_LIKE

* Renamed MX_NUMBER_LIKE to MX_PRIMITIVE_TYPE

* Fixed linting issue

* Now specifying dType from the available data in copyTo and MXDataIter.scala for creating a new DataIterator

* Add primitives support handling to the generator for proper conversion

* Reduced code duplication in classify method in Classifier.scala

* Fix infer package for new signatures and address some bugs

* Removed code duplication in getPixelsArray

* remove debugging

* Changed classifyWithNDArray method in Classifier.scala

* Removed code duplication in predictImpl

* Satisfying lint god _/\_

* Fixed failing PredictorSuite test

* Renamed MX_FLOAT to Camel case

* Revert "Renamed MX_FLOAT to Camel case"

This reverts commit 9d7c3ce.

* Added an implicit conversion from int--> float to support int operations in NDArrays. (These ops were already supported in the previous versions)

* Added Float64 as a training option to ImClassification Suite. Also added integration tests for it

* Satisfy Lint God _/\_

* Added Float64 support in Java NDArray

* Added Float64 support in Java's Predictor API

* Added yours truly to the Contributors list

* Added method comments on Predictor.predict with Array[Double] as a possible input

* Added method comments explaining what MX_PRIMITIVE_TYPE is

*  Fixed errors cause by rebasing with master

* Added licences to the files
  • Loading branch information
piyushghai authored and lanking520 committed Jan 10, 2019
1 parent d973ed4 commit ed7ca26
Show file tree
Hide file tree
Showing 35 changed files with 1,251 additions and 294 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ List of Contributors
* [Yuxi Hu](https://github.com/yuxihu)
* [Harsh Patel](https://github.com/harshp8l)
* [Xiao Wang](https://github.com/BeyonderXX)
* [Piyush Ghai](https://github.com/piyushghai)

Label Bot
---------
Expand Down
242 changes: 117 additions & 125 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
(ns org.apache.clojure-mxnet.infer
(:refer-clojure :exclude [type])
(:require [org.apache.clojure-mxnet.context :as context]
[org.apache.clojure-mxnet.dtype :as dtype]
[org.apache.clojure-mxnet.io :as mx-io]
[org.apache.clojure-mxnet.shape :as shape]
[org.apache.clojure-mxnet.util :as util]
Expand Down Expand Up @@ -62,10 +63,12 @@
(defprotocol AImageClassifier
(classify-image
[wrapped-image-classifier image]
[wrapped-image-classifier image topk])
[wrapped-image-classifier image topk]
[wrapped-image-classifier image topk dtype])
(classify-image-batch
[wrapped-image-classifier images]
[wrapped-image-classifier images topk]))
[wrapped-image-classifier images topk]
[wrapped-image-classifier images topk dtype]))

(defprotocol AObjectDetector
(detect-objects
Expand All @@ -80,7 +83,8 @@

(extend-protocol APredictor
WrappedPredictor
(predict [wrapped-predictor inputs]
(predict
[wrapped-predictor inputs]
(util/validate! ::wrapped-predictor wrapped-predictor
"Invalid predictor")
(util/validate! ::vec-of-float-arrays inputs
Expand All @@ -101,62 +105,50 @@

(extend-protocol AClassifier
WrappedClassifier
(classify [wrapped-classifier inputs]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(classify wrapped-classifier inputs nil))
(classify [wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classify (:classifier wrapped-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk))))
(classify-with-ndarray [wrapped-classifier inputs]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(classify-with-ndarray wrapped-classifier inputs nil))
(classify-with-ndarray [wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classifyWithNDArray (:classifier wrapped-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk))))
(classify
([wrapped-classifier inputs]
(classify wrapped-classifier inputs nil))
([wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classify (:classifier wrapped-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk)))))
(classify-with-ndarray
([wrapped-classifier inputs]
(classify-with-ndarray wrapped-classifier inputs nil))
([wrapped-classifier inputs topk]
(util/validate! ::wrapped-classifier wrapped-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classifyWithNDArray (:classifier wrapped-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk)))))
WrappedImageClassifier
(classify [wrapped-image-classifier inputs]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(classify wrapped-image-classifier inputs nil))
(classify [wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classify (:image-classifier wrapped-image-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk))))
(classify-with-ndarray [wrapped-image-classifier inputs]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
"Invalid inputs")
(classify-with-ndarray wrapped-image-classifier inputs nil))
(classify-with-ndarray [wrapped-image-classifier inputs topk]
(classify
([wrapped-image-classifier inputs]
(classify wrapped-image-classifier inputs nil))
([wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-float-arrays inputs
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classify (:image-classifier wrapped-image-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk)))))
(classify-with-ndarray
([wrapped-image-classifier inputs]
(classify-with-ndarray wrapped-image-classifier inputs nil))
([wrapped-image-classifier inputs topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::vec-of-ndarrays inputs
Expand All @@ -165,83 +157,83 @@
(util/coerce-return-recursive
(.classifyWithNDArray (:image-classifier wrapped-image-classifier)
(util/vec->indexed-seq inputs)
(util/->int-option topk)))))
(util/->int-option topk))))))

(s/def ::image #(instance? BufferedImage %))
(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})

(extend-protocol AImageClassifier
WrappedImageClassifier
(classify-image [wrapped-image-classifier image]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::image image "Invalid image")
(classify-image wrapped-image-classifier image nil))
(classify-image [wrapped-image-classifier image topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classifyImage (:image-classifier wrapped-image-classifier)
image
(util/->int-option topk))))
(classify-image-batch [wrapped-image-classifier images]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(classify-image-batch wrapped-image-classifier images nil))
(classify-image-batch [wrapped-image-classifier images topk]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.classifyImageBatch (:image-classifier wrapped-image-classifier)
images
(util/->int-option topk)))))
(classify-image
([wrapped-image-classifier image]
(classify-image wrapped-image-classifier image nil dtype/FLOAT32))
([wrapped-image-classifier image topk]
(classify-image wrapped-image-classifier image topk dtype/FLOAT32))
([wrapped-image-classifier image topk dtype]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/validate! ::dtype dtype "Invalid dtype")
(util/coerce-return-recursive
(.classifyImage (:image-classifier wrapped-image-classifier)
image
(util/->int-option topk)
dtype))))
(classify-image-batch
([wrapped-image-classifier images]
(classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
([wrapped-image-classifier images topk]
(classify-image-batch wrapped-image-classifier images topk dtype/FLOAT32))
([wrapped-image-classifier images topk dtype]
(util/validate! ::wrapped-image-classifier wrapped-image-classifier
"Invalid classifier")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/validate! ::dtype dtype "Invalid dtype")
(util/coerce-return-recursive
(.classifyImageBatch (:image-classifier wrapped-image-classifier)
images
(util/->int-option topk)
dtype)))))

(extend-protocol AObjectDetector
WrappedObjectDetector
(detect-objects [wrapped-detector image]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::image image "Invalid image")
(detect-objects wrapped-detector image nil))
(detect-objects [wrapped-detector image topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.imageObjectDetect (:object-detector wrapped-detector)
image
(util/->int-option topk))))
(detect-objects-batch [wrapped-detector images]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(detect-objects-batch wrapped-detector images nil))
(detect-objects-batch [wrapped-detector images topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.imageBatchObjectDetect (:object-detector wrapped-detector)
images
(util/->int-option topk))))
(detect-objects-with-ndarrays [wrapped-detector input-arrays]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::vec-of-ndarrays input-arrays
"Invalid inputs")
(detect-objects-with-ndarrays wrapped-detector input-arrays nil))
(detect-objects-with-ndarrays [wrapped-detector input-arrays topk]
(detect-objects
([wrapped-detector image]
(detect-objects wrapped-detector image nil))
([wrapped-detector image topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::vec-of-ndarrays input-arrays
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.objectDetectWithNDArray (:object-detector wrapped-detector)
(util/vec->indexed-seq input-arrays)
(util/validate! ::image image "Invalid image")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.imageObjectDetect (:object-detector wrapped-detector)
image
(util/->int-option topk)))))
(detect-objects-batch
([wrapped-detector images]
(detect-objects-batch wrapped-detector images nil))
([wrapped-detector images topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.imageBatchObjectDetect (:object-detector wrapped-detector)
images
(util/->int-option topk)))))
(detect-objects-with-ndarrays
([wrapped-detector input-arrays]
(detect-objects-with-ndarrays wrapped-detector input-arrays nil))
([wrapped-detector input-arrays topk]
(util/validate! ::wrapped-detector wrapped-detector
"Invalid object detector")
(util/validate! ::vec-of-ndarrays input-arrays
"Invalid inputs")
(util/validate! ::nil-or-int topk "Invalid top-K")
(util/coerce-return-recursive
(.objectDetectWithNDArray (:object-detector wrapped-detector)
(util/vec->indexed-seq input-arrays)
(util/->int-option topk))))))

(defprotocol AInferenceFactory
(create-predictor [factory] [factory opts])
Expand Down Expand Up @@ -335,7 +327,7 @@
[image input-shape-vec]
(util/validate! ::image image "Invalid image")
(util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
(ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec)))
(ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32))

(s/def ::image-path string?)
(s/def ::image-paths (s/coll-of ::image-path))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.primitives
(:import (org.apache.mxnet MX_PRIMITIVES$MX_FLOAT MX_PRIMITIVES$MX_Double
MX_PRIMITIVES$MX_PRIMITIVE_TYPE)))


;;; Defines customer mx primitives that can be used for mathematical computations
;;; in NDArrays to control precision. Currently Float and Double are supported

;;; For purposes of automatic conversion in ndarray functions, doubles are default
;; to specify using floats you must use a Float

(defn mx-float
"Creates a MXNet float primitive"
[num]
(new MX_PRIMITIVES$MX_FLOAT num))

(defn mx-double
"Creates a MXNet double primitive"
[num]
(new MX_PRIMITIVES$MX_Double num))

(defn ->num
"Returns the underlying number value"
[primitive]
(.data primitive))

(defn primitive? [x]
(instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE x))

Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
(:require [clojure.spec.alpha :as s]
[t6.from-scala.core :refer [$ $$] :as $]
[clojure.string :as string]
[org.apache.clojure-mxnet.primitives :as primitives]
[org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet NDArray)
(scala Product Tuple2 Tuple3)
Expand All @@ -36,7 +37,8 @@
"byte<>" "byte-array"
"java.lang.String<>" "vec-or-strings"
"org.apache.mxnet.NDArray" "ndarray"
"org.apache.mxnet.Symbol" "sym"})
"org.apache.mxnet.Symbol" "sym"
"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"})

(def symbol-param-coerce {"java.lang.String" "sym-name"
"float" "num"
Expand Down Expand Up @@ -144,6 +146,8 @@
(and (get targets "int<>") (vector? param)) (int-array param)
(and (get targets "float<>") (vector? param)) (float-array param)
(and (get targets "java.lang.String<>") (vector? param)) (into-array param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param)
(and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param)
:else param))

(defn nil-or-coerce-param [param targets]
Expand Down Expand Up @@ -177,6 +181,7 @@
(instance? Map return-val) (scala-map->map return-val)
(instance? Tuple2 return-val) (tuple->vec return-val)
(instance? Tuple3 return-val) (tuple->vec return-val)
(primitives/primitive? return-val) (primitives/->num return-val)
:else return-val))

(defn coerce-return-recursive [return-val]
Expand Down
Loading

0 comments on commit ed7ca26

Please sign in to comment.