-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-1260] Float64 DType computation support in Scala/Java #13678
Conversation
@mxnet-label-bot Add [pr-work-in-progress, Scala] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add some Scala test cases
@lanking520 I have already added the Scala tests in NDArraySuite.scala class. It's in this commit : 5529f94 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks to your great contribution! Overall look clean and tidy!
@mxnet-label-bot remove [pr-work-in-progress] |
@mxnet-label-bot Add [pr-awaiting-review] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
Show resolved
Hide resolved
…ature, and public methods calling the Impls
….scala for creating a new DataIterator
This reverts commit 9d7c3ce.
…ons in NDArrays. (These ops were already supported in the previous versions)
…ded integration tests for it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@piyushghai please Satisfy the lint god (╯‵□′)╯︵┻━┻ |
@piyushghai please Satisfy the lint god (‵□′) |
* upstream/master: (109 commits) Code modification for testcases of various network models in directory example (apache#12498) [CI] Prevent timeouts when rebuilding containers with docker. (apache#13818) fix Makefile for rpkg (apache#13590) change to compile time (apache#13835) Disabled flaky test (apache#13758) Improve license_header tool by only traversing files under revision c… (apache#13803) Removes unneeded nvidia driver ppa installation (apache#13814) Add Local test stage and option to jump directly to menu item from commandline (apache#13809) Remove MXNET_STORAGE_FALLBACK_LOG_VERBOSE from test_autograd.py (apache#13830) Fix scala doc build break for v1.3.1 (apache#13820) [MXNET-1263] Unit Tests for Java Predictor and Object Detector APIs (apache#13794) [MXNET-1260] Float64 DType computation support in Scala/Java (apache#13678) onnx export ops (apache#13821) [MXNET-880] ONNX export: Random uniform, Random normal, MaxRoiPool (apache#13676) fix minor indentation (apache#13827) Fixing a symlink issue with R install (apache#13708) remove useless code (apache#13777) ONNX ops: norm exported and lpnormalization imported (apache#13806) Add new Maven build for Scala package (apache#13819) Dockerfiles for Publish Testing (apache#13707) ...
…pache#13678)" This reverts commit ed7ca26.
…13678) * Added Float64 as a supported datatype in NDArray * Added unit tests for Float64 in NDArray * Fix for failing Clojure unit tests * Added Float and Double as MX_PRIMITIVES for computation in Scala * Trying out second approach --> Private Impl methods with generic signature, and public methods calling the Impls * Fixed errors in *= method * Added Float64 in IO.scala and DataIter.scala * Added another testcase for IO.DataDesc creation * Fixed failing CI * Added Float64 in Predictor class * Added Float64 in Classifier class * Added Double as a possible return type to : classifyWithNDArray * Added unit tests for Classifier and Predictor.scala classes for Float64/Double * Approach 3 --> Using a trait to mirror Float and Double in Scala * Added comments on MX_PRIMITIVES.scala * Added Float64/Double support for inference in ImageClassifier APIs * Added unary- and compareTo in MX_NUMBER_LIKE * Renamed MX_NUMBER_LIKE to MX_PRIMITIVE_TYPE * Fixed linting issue * Now specifying dType from the available data in copyTo and MXDataIter.scala for creating a new DataIterator * Add primitives support handling to the generator for proper conversion * Reduced code duplication in classify method in Classifier.scala * Fix infer package for new signatures and address some bugs * Removed code duplication in getPixelsArray * remove debugging * Changed classifyWithNDArray method in Classifier.scala * Removed code duplication in predictImpl * Satisfying lint god _/\_ * Fixed failing PredictorSuite test * Renamed MX_FLOAT to Camel case * Revert "Renamed MX_FLOAT to Camel case" This reverts commit 9d7c3ce. * Added an implicit conversion from int--> float to support int operations in NDArrays. (These ops were already supported in the previous versions) * Added Float64 as a training option to ImClassification Suite. Also added integration tests for it * Satisfy Lint God _/\_ * Added Float64 support in Java NDArray * Added Float64 support in Java's Predictor API * Added yours truly to the Contributors list * Added method comments on Predictor.predict with Array[Double] as a possible input * Added method comments explaining what MX_PRIMITIVE_TYPE is * Fixed errors cause by rebasing with master * Added licences to the files
Description
This PR introduces Float64/Double data type support in NDArrays in Scala. Currently we only allow precision upto Float32 in Scala as a result of which there are issues when one tries to load a model trained using float64 (in another language binding).
This also fixes two long standing issues : fixes #11315 & fixes #10338
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Comments