Skip to content

Commit

Permalink
Fix relative difference scala (apache#14417)
Browse files Browse the repository at this point in the history
* Fix relative difference scala

* Increase number of cases for scala arange test

* Add cases where arange produces NDArray of [0]

* Remote whitespace
  • Loading branch information
kiendang authored and vdantu committed Mar 31, 2019
1 parent 6ba4759 commit 4dd5f87
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ object CheckUtils {
def reldiff(a: NDArray, b: NDArray): Float = {
val diff = NDArray.sum(NDArray.abs(a - b)).toScalar
val norm = NDArray.sum(NDArray.abs(a)).toScalar
diff / norm
if (diff < Float.MinPositiveValue) diff else diff / norm
}

def reldiff(a: Array[Float], b: Array[Float]): Float = {
val diff =
(a zip b).map { case (aElem, bElem) => Math.abs(aElem - bElem) }.sum
val norm: Float = a.reduce(Math.abs(_) + Math.abs(_))
diff / norm
if (diff < Float.MinPositiveValue) diff else diff / norm
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,13 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
val result3 = 0f to stop by 1f
val range3 = NDArray.arange(stop)
assert(CheckUtils.reldiff(result3.toArray, range3.toArray) <= 1e-4f)

val stop4 = Math.abs(stop)
val step4 = stop4 + Math.abs(scala.util.Random.nextFloat())
val result4 = (0.0 until stop4.toDouble by step4.toDouble)
.flatMap(x => Array.fill[Float](repeat)(x.toFloat))
val range4 = NDArray.arange(stop4, step = step4, repeat = repeat)
assert(CheckUtils.reldiff(result4.toArray, range4.toArray) <= 1e-4f)
}
}

Expand Down

0 comments on commit 4dd5f87

Please sign in to comment.