This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MXNET-1260] Float64 DType computation support in Scala/Java (#13678)
* Added Float64 as a supported datatype in NDArray * Added unit tests for Float64 in NDArray * Fix for failing Clojure unit tests * Added Float and Double as MX_PRIMITIVES for computation in Scala * Trying out second approach --> Private Impl methods with generic signature, and public methods calling the Impls * Fixed errors in *= method * Added Float64 in IO.scala and DataIter.scala * Added another testcase for IO.DataDesc creation * Fixed failing CI * Added Float64 in Predictor class * Added Float64 in Classifier class * Added Double as a possible return type to : classifyWithNDArray * Added unit tests for Classifier and Predictor.scala classes for Float64/Double * Approach 3 --> Using a trait to mirror Float and Double in Scala * Added comments on MX_PRIMITIVES.scala * Added Float64/Double support for inference in ImageClassifier APIs * Added unary- and compareTo in MX_NUMBER_LIKE * Renamed MX_NUMBER_LIKE to MX_PRIMITIVE_TYPE * Fixed linting issue * Now specifying dType from the available data in copyTo and MXDataIter.scala for creating a new DataIterator * Add primitives support handling to the generator for proper conversion * Reduced code duplication in classify method in Classifier.scala * Fix infer package for new signatures and address some bugs * Removed code duplication in getPixelsArray * remove debugging * Changed classifyWithNDArray method in Classifier.scala * Removed code duplication in predictImpl * Satisfying lint god _/\_ * Fixed failing PredictorSuite test * Renamed MX_FLOAT to Camel case * Revert "Renamed MX_FLOAT to Camel case" This reverts commit 9d7c3ce. * Added an implicit conversion from int--> float to support int operations in NDArrays. (These ops were already supported in the previous versions) * Added Float64 as a training option to ImClassification Suite. Also added integration tests for it * Satisfy Lint God _/\_ * Added Float64 support in Java NDArray * Added Float64 support in Java's Predictor API * Added yours truly to the Contributors list * Added method comments on Predictor.predict with Array[Double] as a possible input * Added method comments explaining what MX_PRIMITIVE_TYPE is * Fixed errors cause by rebasing with master * Added licences to the files
- Loading branch information
1 parent
d973ed4
commit ed7ca26
Showing
35 changed files
with
1,251 additions
and
294 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
;; | ||
;; Licensed to the Apache Software Foundation (ASF) under one or more | ||
;; contributor license agreements. See the NOTICE file distributed with | ||
;; this work for additional information regarding copyright ownership. | ||
;; The ASF licenses this file to You under the Apache License, Version 2.0 | ||
;; (the "License"); you may not use this file except in compliance with | ||
;; the License. You may obtain a copy of the License at | ||
;; | ||
;; http://www.apache.org/licenses/LICENSE-2.0 | ||
;; | ||
;; Unless required by applicable law or agreed to in writing, software | ||
;; distributed under the License is distributed on an "AS IS" BASIS, | ||
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
;; See the License for the specific language governing permissions and | ||
;; limitations under the License. | ||
;; | ||
|
||
(ns org.apache.clojure-mxnet.primitives | ||
(:import (org.apache.mxnet MX_PRIMITIVES$MX_FLOAT MX_PRIMITIVES$MX_Double | ||
MX_PRIMITIVES$MX_PRIMITIVE_TYPE))) | ||
|
||
|
||
;;; Defines customer mx primitives that can be used for mathematical computations | ||
;;; in NDArrays to control precision. Currently Float and Double are supported | ||
|
||
;;; For purposes of automatic conversion in ndarray functions, doubles are default | ||
;; to specify using floats you must use a Float | ||
|
||
(defn mx-float | ||
"Creates a MXNet float primitive" | ||
[num] | ||
(new MX_PRIMITIVES$MX_FLOAT num)) | ||
|
||
(defn mx-double | ||
"Creates a MXNet double primitive" | ||
[num] | ||
(new MX_PRIMITIVES$MX_Double num)) | ||
|
||
(defn ->num | ||
"Returns the underlying number value" | ||
[primitive] | ||
(.data primitive)) | ||
|
||
(defn primitive? [x] | ||
(instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE x)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.