diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj index 89c0ef820609..1b4b2ea2fbe3 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj @@ -226,12 +226,12 @@ (let [arange (sym/arange-with-inference 0) data (sym/variable "data") added (sym/+ arange data) - result (range 0. 4.) + result (range 0 4) data-tmp (ndarray/zeros [4]) exec (sym/bind added (context/default-context) {"data" data-tmp})] (executor/forward exec) (is (= 0 (count (executor/grad-arrays exec)))) - (is (= result (-> (executor/outputs exec) (first) (ndarray/->vec)))))) + (is (approx= 1e-4 result (-> (executor/outputs exec) (first)))))) (deftest test-scalar-pow (let [data (sym/variable "data") diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj index dcdbea645796..ecd54ca72773 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj @@ -22,6 +22,8 @@ (if (and (number? x) (number? y)) (let [diff (Math/abs (- x y))] (< diff tolerance)) - (reduce (fn [x y] (and x y)) - (map #(approx= tolerance %1 %2) x y)))) + (and + (= (count x) (count y)) + (reduce (fn [x y] (and x y)) + (map #(approx= tolerance %1 %2) x y))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj index 5551fab435f6..de3480827ba4 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj @@ -21,6 +21,7 @@ [org.apache.clojure-mxnet.util :as util] [org.apache.clojure-mxnet.ndarray :as ndarray] [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.test-util :as test-util] [clojure.spec.alpha :as s]) (:import (org.apache.mxnet Shape NDArrayFuncReturn NDArray) (scala.collection Map Set) @@ -183,3 +184,10 @@ (deftest test-validate (is (nil? (util/validate! string? "foo" "Not a string!"))) (is (thrown-with-msg? Exception #"Not a string!" (util/validate! ::x 1 "Not a string!")))) + +(deftest test-approx= + (let [data1 [1 1 1 1] + data2 [1 1 1 1 9 9 9 9] + data3 [1 1 1 2]] + (is (not (test-util/approx= 1e-9 data1 data2))) + (is (test-util/approx= 2 data1 data3)))) \ No newline at end of file