diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj index 151e18bcb482..9caa00d49010 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj @@ -94,6 +94,27 @@ ([start stop] (arange start stop {}))) +(defn ->ndarray + "Creates a new NDArray based on the given n-dimenstional vector + of numbers. + `nd-vec`: n-dimensional vector with numbers. + `opts-map` { + `ctx`: Context of the output ndarray, will use default context if unspecified. + } + returns: `ndarray` with the given values and matching the shape of the input vector. + Ex: + (->ndarray [5.0 -4.0]) + (->ndarray [5 -4] {:ctx (context/cpu)}) + (->ndarray [[1 2 3] [4 5 6]]) + (->ndarray [[[1.0] [2.0]]]" + ([nd-vec {:keys [ctx] + :or {ctx (mx-context/default-context)} + :as opts}] + (array (vec (clojure.core/flatten nd-vec)) + (util/nd-seq-shape nd-vec) + {:ctx ctx})) + ([nd-vec] (->ndarray nd-vec {}))) + (defn slice "Return a sliced NDArray that shares memory with current one." ([ndarray i] diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj index 6b5f50792ead..7eb1426d0d3e 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj @@ -218,6 +218,26 @@ (throw (ex-info error-msg (s/explain-data spec value))))) +(s/def ::non-empty-seq (s/and sequential? not-empty)) +(defn to-array-nd + "Converts any N-D sequential structure to an array + with the same dimensions." + [nd-seq] + (validate! ::non-empty-seq nd-seq "Invalid N-D sequence") + (if (sequential? (first nd-seq)) + (to-array (mapv to-array-nd nd-seq)) + (to-array nd-seq))) + +(defn nd-seq-shape + "Computes the shape of a n-dimensional sequential structure" + [nd-seq] + (validate! ::non-empty-seq nd-seq "Invalid N-D sequence") + (loop [s nd-seq + shape [(count s)]] + (if (sequential? (first s)) + (recur (first s) (conj shape (count (first s)))) + shape))) + (defn map->scala-tuple-seq "* Convert a map to a scala-Seq of scala-Tubple. * Should also work if a seq of seq of 2 things passed. diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj index a9ae2966db89..ee7c16b737f6 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj @@ -146,6 +146,18 @@ (is (= [0.0 0.0 0.5 0.5 1.0 1.0 1.5 1.5 2.0 2.0 2.5 2.5 3.0 3.0 3.5 3.5 4.0 4.0 4.5 4.5] (->vec (ndarray/arange start stop {:step step :repeat repeat})))))) +(deftest test->ndarray + (let [nda1 (ndarray/->ndarray [5.0 -4.0]) + nda2 (ndarray/->ndarray [[1 2 3] + [4 5 6]]) + nda3 (ndarray/->ndarray [[[7.0] [8.0]]])] + (is (= [5.0 -4.0] (->vec nda1))) + (is (= [2] (mx-shape/->vec (shape nda1)))) + (is (= [1.0 2.0 3.0 4.0 5.0 6.0] (->vec nda2))) + (is (= [2 3] (mx-shape/->vec (shape nda2)))) + (is (= [7.0 8.0] (->vec nda3))) + (is (= [1 2 1] (mx-shape/->vec (shape nda3)))))) + (deftest test-power (let [nda (ndarray/array [3 5] [2 1])] diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj index 4ed7d38e690a..15c4859c77a6 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj @@ -163,6 +163,33 @@ (is (= [1 2] (-> (util/convert-tuple [1 2]) (util/tuple->vec))))) +(deftest test-to-array-nd + (let [a1 (util/to-array-nd '(1)) + a2 (util/to-array-nd [1.0 2.0]) + a3 (util/to-array-nd [[3.0] [4.0]]) + a4 (util/to-array-nd [[[5 -5]]])] + (is (= 1 (alength a1))) + (is (= [1] (->> a1 vec))) + (is (= 2 (alength a2))) + (is (= 2.0 (aget a2 1))) + (is (= [1.0 2.0] (->> a2 vec))) + (is (= 2 (alength a3))) + (is (= 1 (alength (aget a3 0)))) + (is (= 4.0 (aget a3 1 0))) + (is (= [[3.0] [4.0]] (->> a3 vec (mapv vec)))) + (is (= 1 (alength a4))) + (is (= 1 (alength (aget a4 0)))) + (is (= 2 (alength (aget a4 0 0)))) + (is (= 5 (aget a4 0 0 0))) + (is (= [[[5 -5]]] (->> a4 vec (mapv vec) (mapv #(mapv vec %))))))) + +(deftest test-nd-seq-shape + (is (= [1] (util/nd-seq-shape '(5)))) + (is (= [2] (util/nd-seq-shape [1.0 2.0]))) + (is (= [3] (util/nd-seq-shape [1 1 1]))) + (is (= [2 1] (util/nd-seq-shape [[3.0] [4.0]]))) + (is (= [1 3 2] (util/nd-seq-shape [[[5 -5] [5 -5] [5 -5]]])))) + (deftest test-coerce-return (is (= [] (util/coerce-return (ArrayBuffer.)))) (is (= [1 2 3] (util/coerce-return (util/vec->indexed-seq [1 2 3]))))