Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Clojure] Helper function for n-dim vector to ndarray (#14305)
Browse files Browse the repository at this point in the history
* [Clojure] Helper function for n-dim vector to ndarray

* More tests, specs and rename method

* Address comments

* Allow every number type
  • Loading branch information
kedarbellare authored and nswamy committed Apr 5, 2019
1 parent 19f05b0 commit f981f4e
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 0 deletions.
21 changes: 21 additions & 0 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,27 @@
([start stop]
(arange start stop {})))

(defn ->ndarray
"Creates a new NDArray based on the given n-dimenstional vector
of numbers.
`nd-vec`: n-dimensional vector with numbers.
`opts-map` {
`ctx`: Context of the output ndarray, will use default context if unspecified.
}
returns: `ndarray` with the given values and matching the shape of the input vector.
Ex:
(->ndarray [5.0 -4.0])
(->ndarray [5 -4] {:ctx (context/cpu)})
(->ndarray [[1 2 3] [4 5 6]])
(->ndarray [[[1.0] [2.0]]]"
([nd-vec {:keys [ctx]
:or {ctx (mx-context/default-context)}
:as opts}]
(array (vec (clojure.core/flatten nd-vec))
(util/nd-seq-shape nd-vec)
{:ctx ctx}))
([nd-vec] (->ndarray nd-vec {})))

(defn slice
"Return a sliced NDArray that shares memory with current one."
([ndarray i]
Expand Down
20 changes: 20 additions & 0 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,26 @@
(throw (ex-info error-msg
(s/explain-data spec value)))))

(s/def ::non-empty-seq (s/and sequential? not-empty))
(defn to-array-nd
"Converts any N-D sequential structure to an array
with the same dimensions."
[nd-seq]
(validate! ::non-empty-seq nd-seq "Invalid N-D sequence")
(if (sequential? (first nd-seq))
(to-array (mapv to-array-nd nd-seq))
(to-array nd-seq)))

(defn nd-seq-shape
"Computes the shape of a n-dimensional sequential structure"
[nd-seq]
(validate! ::non-empty-seq nd-seq "Invalid N-D sequence")
(loop [s nd-seq
shape [(count s)]]
(if (sequential? (first s))
(recur (first s) (conj shape (count (first s))))
shape)))

(defn map->scala-tuple-seq
"* Convert a map to a scala-Seq of scala-Tubple.
* Should also work if a seq of seq of 2 things passed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@
(is (= [0.0 0.0 0.5 0.5 1.0 1.0 1.5 1.5 2.0 2.0 2.5 2.5 3.0 3.0 3.5 3.5 4.0 4.0 4.5 4.5]
(->vec (ndarray/arange start stop {:step step :repeat repeat}))))))

(deftest test->ndarray
(let [nda1 (ndarray/->ndarray [5.0 -4.0])
nda2 (ndarray/->ndarray [[1 2 3]
[4 5 6]])
nda3 (ndarray/->ndarray [[[7.0] [8.0]]])]
(is (= [5.0 -4.0] (->vec nda1)))
(is (= [2] (mx-shape/->vec (shape nda1))))
(is (= [1.0 2.0 3.0 4.0 5.0 6.0] (->vec nda2)))
(is (= [2 3] (mx-shape/->vec (shape nda2))))
(is (= [7.0 8.0] (->vec nda3)))
(is (= [1 2 1] (mx-shape/->vec (shape nda3))))))

(deftest test-power
(let [nda (ndarray/array [3 5] [2 1])]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,33 @@
(is (= [1 2] (-> (util/convert-tuple [1 2])
(util/tuple->vec)))))

(deftest test-to-array-nd
(let [a1 (util/to-array-nd '(1))
a2 (util/to-array-nd [1.0 2.0])
a3 (util/to-array-nd [[3.0] [4.0]])
a4 (util/to-array-nd [[[5 -5]]])]
(is (= 1 (alength a1)))
(is (= [1] (->> a1 vec)))
(is (= 2 (alength a2)))
(is (= 2.0 (aget a2 1)))
(is (= [1.0 2.0] (->> a2 vec)))
(is (= 2 (alength a3)))
(is (= 1 (alength (aget a3 0))))
(is (= 4.0 (aget a3 1 0)))
(is (= [[3.0] [4.0]] (->> a3 vec (mapv vec))))
(is (= 1 (alength a4)))
(is (= 1 (alength (aget a4 0))))
(is (= 2 (alength (aget a4 0 0))))
(is (= 5 (aget a4 0 0 0)))
(is (= [[[5 -5]]] (->> a4 vec (mapv vec) (mapv #(mapv vec %)))))))

(deftest test-nd-seq-shape
(is (= [1] (util/nd-seq-shape '(5))))
(is (= [2] (util/nd-seq-shape [1.0 2.0])))
(is (= [3] (util/nd-seq-shape [1 1 1])))
(is (= [2 1] (util/nd-seq-shape [[3.0] [4.0]])))
(is (= [1 3 2] (util/nd-seq-shape [[[5 -5] [5 -5] [5 -5]]]))))

(deftest test-coerce-return
(is (= [] (util/coerce-return (ArrayBuffer.))))
(is (= [1 2 3] (util/coerce-return (util/vec->indexed-seq [1 2 3]))))
Expand Down

0 comments on commit f981f4e

Please sign in to comment.