From 4dd5f87ceafa2c5b41aa9a26d178c23327742c59 Mon Sep 17 00:00:00 2001 From: Dang Trung Kien Date: Thu, 14 Mar 2019 02:07:41 +0800 Subject: [PATCH] Fix relative difference scala (#14417) * Fix relative difference scala * Increase number of cases for scala arange test * Add cases where arange produces NDArray of [0] * Remote whitespace --- .../core/src/test/scala/org/apache/mxnet/CheckUtils.scala | 4 ++-- .../src/test/scala/org/apache/mxnet/NDArraySuite.scala | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala b/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala index 1ddb292dc3d2..7602b53edc9e 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/CheckUtils.scala @@ -21,13 +21,13 @@ object CheckUtils { def reldiff(a: NDArray, b: NDArray): Float = { val diff = NDArray.sum(NDArray.abs(a - b)).toScalar val norm = NDArray.sum(NDArray.abs(a)).toScalar - diff / norm + if (diff < Float.MinPositiveValue) diff else diff / norm } def reldiff(a: Array[Float], b: Array[Float]): Float = { val diff = (a zip b).map { case (aElem, bElem) => Math.abs(aElem - bElem) }.sum val norm: Float = a.reduce(Math.abs(_) + Math.abs(_)) - diff / norm + if (diff < Float.MinPositiveValue) diff else diff / norm } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 72a5974801a1..206094c15958 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -355,6 +355,13 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { val result3 = 0f to stop by 1f val range3 = NDArray.arange(stop) assert(CheckUtils.reldiff(result3.toArray, range3.toArray) <= 1e-4f) + + val stop4 = Math.abs(stop) + val step4 = stop4 + Math.abs(scala.util.Random.nextFloat()) + val result4 = (0.0 until stop4.toDouble by step4.toDouble) + .flatMap(x => Array.fill[Float](repeat)(x.toFloat)) + val range4 = NDArray.arange(stop4, step = step4, repeat = repeat) + assert(CheckUtils.reldiff(result4.toArray, range4.toArray) <= 1e-4f) } }